Browse Source

Multithreading with Keras. #890

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
029779f1da
14 changed files with 205 additions and 240 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.graph.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Contexts/Context.cs
  3. +5
    -12
      src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
  4. +1
    -1
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs
  7. +28
    -85
      src/TensorFlowNET.Core/ops.threading.cs
  8. +10
    -7
      src/TensorFlowNET.Core/tensorflow.cs
  9. +0
    -53
      src/TensorFlowNET.Core/tensorflow.threading.cs
  10. +1
    -0
      src/TensorFlowNET.Keras/KerasInterface.cs
  11. +8
    -7
      src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs
  12. +52
    -67
      test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs
  13. +95
    -0
      test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs
  14. +0
    -3
      test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs

+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.graph.cs View File

@@ -20,8 +20,8 @@ namespace Tensorflow
{
public partial class tensorflow
{
public graph_util_impl graph_util => new graph_util_impl();
public GraphTransformer graph_transforms => new GraphTransformer();
public graph_util_impl graph_util { get; } = new graph_util_impl();
public GraphTransformer graph_transforms { get; } = new GraphTransformer();
public GraphKeys GraphKeys { get; } = new GraphKeys();

public void reset_default_graph()


+ 1
- 1
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -171,7 +171,7 @@ namespace Tensorflow.Contexts

public void reset_context()
{
ops.reset_uid();
// ops.reset_uid();
// tf.defaultSession = null;
ops.reset_default_graph();
context_switches.Clear();


+ 5
- 12
src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs View File

@@ -14,9 +14,8 @@
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -25,19 +24,14 @@ namespace Tensorflow
/// </summary>
public class DefaultGraphStack
{
private readonly Stack<Graph> _stack = new Stack<Graph>();
Graph _global_default_graph;
Stack<Graph> _stack = new Stack<Graph>();

public Graph get_default()
{
if (_stack.Count > 0)
return _stack.Peek();
else if (_global_default_graph != null)
return _global_default_graph;
else
_global_default_graph = new Graph();
if (_stack.Count == 0)
_stack.Push(new Graph());

return _global_default_graph;
return _stack.Peek();
}

public Graph get_controller(Graph g)
@@ -61,7 +55,6 @@ namespace Tensorflow
public void reset()
{
_stack.Clear();
_global_default_graph = null;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -107,7 +107,7 @@ namespace Tensorflow.NumPy
if (tensor.Handle == null)
{
if (tf.executing_eagerly())
tensor = tf.defaultSession.eval(tensor);
tensor = tf.get_default_session().eval(tensor);
}

return new NDArray(tensor, tf.executing_eagerly());


+ 1
- 1
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow.NumPy
{
if (_handle is null)
{
tensor = tf.defaultSession.eval(tensor);
tensor = tf.get_default_session().eval(tensor);
_handle = tensor.Handle;
}



+ 1
- 1
src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow.Variables
{
// gen_resource_variable_ops.destroy_resource_op(_tensor, ignore_lookup_error: true);

tf.device(_handle_device);
// tf.device(_handle_device);
tf.Runner.TFE_Execute(tf.Context, _handle_device, "DestroyResourceOp",
new[] { _tensor },
new object[] { "ignore_lookup_error", true }, 0);


+ 28
- 85
src/TensorFlowNET.Core/ops.threading.cs View File

@@ -1,70 +1,15 @@
using System.Threading;
using System;
using System.Threading;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class ops
{
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());
private static volatile Session _singleSesson;
private static volatile DefaultGraphStack _singleGraphStack;
private static readonly object _threadingLock = new object();

public static DefaultGraphStack default_graph_stack
{
get
{
if (!isSingleThreaded)
return _defaultGraphFactory.Value;

if (_singleGraphStack == null)
{
lock (_threadingLock)
{
if (_singleGraphStack == null)
_singleGraphStack = new DefaultGraphStack();
}
}

return _singleGraphStack;
}
}

private static bool isSingleThreaded = false;

/// <summary>
/// Does this library ignore different thread accessing.
/// </summary>
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks>
public static bool IsSingleThreaded
{
get => isSingleThreaded;
set
{
if (value)
enforce_singlethreading();
else
enforce_multithreading();
}
}

/// <summary>
/// Forces the library to ignore different thread accessing.
/// </summary>
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks>
public static void enforce_singlethreading()
{
isSingleThreaded = true;
}

/// <summary>
/// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing.
/// </summary>
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks>
public static void enforce_multithreading()
{
isSingleThreaded = false;
}
[ThreadStatic]
static DefaultGraphStack default_graph_stack = new DefaultGraphStack();
[ThreadStatic]
static Session defaultSession;

/// <summary>
/// Returns the default session for the current thread.
@@ -72,19 +17,10 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session()
{
if (!isSingleThreaded)
return tf.defaultSession;
if (defaultSession == null)
defaultSession = new Session(tf.get_default_graph());

if (_singleSesson == null)
{
lock (_threadingLock)
{
if (_singleSesson == null)
_singleSesson = new Session();
}
}

return _singleSesson;
return defaultSession;
}

/// <summary>
@@ -93,15 +29,8 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns>
public static Session set_default_session(Session sess)
{
if (!isSingleThreaded)
return tf.defaultSession = sess;

lock (_threadingLock)
{
_singleSesson = sess;
}

return _singleSesson;
defaultSession = sess;
return sess;
}

/// <summary>
@@ -118,10 +47,18 @@ namespace Tensorflow
/// </summary>
/// <returns></returns>
public static Graph get_default_graph()
=> default_graph_stack.get_default();
{
if (default_graph_stack == null)
default_graph_stack = new DefaultGraphStack();
return default_graph_stack.get_default();
}

public static Graph set_default_graph(Graph g)
=> default_graph_stack.get_controller(g);
{
if (default_graph_stack == null)
default_graph_stack = new DefaultGraphStack();
return default_graph_stack.get_controller(g);
}

/// <summary>
/// Clears the default graph stack and resets the global default graph.
@@ -135,6 +72,8 @@ namespace Tensorflow
/// <returns></returns>
public static void reset_default_graph()
{
if (default_graph_stack == null)
return;
//if (!_default_graph_stack.is_cleared())
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
// "nested graphs. If you need a cleared graph, " +
@@ -143,7 +82,11 @@ namespace Tensorflow
}

public static Graph peak_default_graph()
=> default_graph_stack.peak_controller();
{
if (default_graph_stack == null)
default_graph_stack = new DefaultGraphStack();
return default_graph_stack.peak_controller();
}

public static void pop_graph()
=> default_graph_stack.pop();


+ 10
- 7
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -16,6 +16,7 @@

using Serilog;
using Serilog.Core;
using System.Threading;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Gradients;
@@ -38,12 +39,18 @@ namespace Tensorflow
public TF_DataType chars = TF_DataType.TF_STRING;
public TF_DataType @string = TF_DataType.TF_STRING;

public Status Status;
public OpDefLibrary OpDefLib;
public Context Context;
public IEagerRunner Runner;
public Logger Logger;

ThreadLocal<Status> _status = new ThreadLocal<Status>(() => new Status());
public Status Status => _status.Value;

ThreadLocal<Context> _context = new ThreadLocal<Context>(() => new Context());
public Context Context => _context.Value;

ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner());
public IEagerRunner Runner => _runner.Value;

public tensorflow()
{
Logger = new LoggerConfiguration()
@@ -51,12 +58,8 @@ namespace Tensorflow
.WriteTo.Console()
.CreateLogger();

Status = new Status();
Context = new Context();
OpDefLib = new OpDefLibrary();
ConstructThreadingObjects();
InitGradientEnvironment();
Runner = new EagerRunner();
}

public string VERSION => c_api.StringPiece(c_api.TF_Version());


+ 0
- 53
src/TensorFlowNET.Core/tensorflow.threading.cs View File

@@ -1,53 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.Runtime.CompilerServices;
using System.Threading;

namespace Tensorflow
{
public partial class tensorflow
{
protected ThreadLocal<Session> defaultSessionFactory;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void ConstructThreadingObjects()
{
defaultSessionFactory = new ThreadLocal<Session>(() => new Session());
}

public Session defaultSession
{
get
{
if (!ops.IsSingleThreaded)
return defaultSessionFactory.Value;

return ops.get_default_session();
}
internal set
{
if (!ops.IsSingleThreaded)
{
defaultSessionFactory.Value = value;
return;
}

ops.set_default_session(value);
}
}
}
}

+ 1
- 0
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -12,6 +12,7 @@ using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using System.Threading;

namespace Tensorflow.Keras
{


+ 8
- 7
src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs View File

@@ -9,6 +9,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using Tensorflow.Functions;
using System.Threading;

namespace Tensorflow.Keras.Layers
{
@@ -40,24 +41,24 @@ namespace Tensorflow.Keras.Layers
return MakOp(inputs);
}

ConcreteFunction function;
ThreadLocal<ConcreteFunction> function = new ThreadLocal<ConcreteFunction>();
Tensors DeFunCall(Tensors inputs)
{
if(function == null)
if (function.Value == null)
{
function = new ConcreteFunction(name);
function.Enter();
function.Value = new ConcreteFunction(name);
function.Value.Enter();

int i = 0;
var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray();
var graph_outputs = MakOp(graph_inputs);
graph_outputs = mark_as_return(graph_outputs);

function.ToGraph(graph_inputs, graph_outputs);
function.Exit();
function.Value.ToGraph(graph_inputs, graph_outputs);
function.Value.Exit();
}

var outputs = function.FilteredCall(inputs);
var outputs = function.Value.FilteredCall(inputs);
return outputs;
}



+ 52
- 67
test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs View File

@@ -24,14 +24,12 @@ namespace TensorFlowNET.UnitTest
{
Assert.IsNull(tf.peak_default_graph());

using (var sess = tf.Session())
{
var default_graph = tf.get_default_graph();
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);
}
using var sess = tf.Session();
var default_graph = tf.get_default_graph();
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);
}
}

@@ -47,14 +45,12 @@ namespace TensorFlowNET.UnitTest
{
Assert.IsNull(tf.peak_default_graph());
//tf.Session created an other graph
using (var sess = tf.Session())
{
var default_graph = tf.get_default_graph();
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);
}
using var sess = tf.Session();
var default_graph = tf.get_default_graph();
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);
}
}

@@ -73,20 +69,12 @@ namespace TensorFlowNET.UnitTest
beforehand.as_default();
Assert.IsNotNull(tf.peak_default_graph());

using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);

Console.WriteLine($"{tid}-{default_graph.graph_key}");

//var result = sess.run(new object[] {g, a});
//var actualDeriv = result[0].GetData<float>()[0];
//var actual = result[1].GetData<float>()[0];
}
using var sess = tf.Session();
var default_graph = tf.peak_default_graph();
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);
}
}

@@ -114,13 +102,10 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
using (var sess = tf.Session())
using var sess = tf.Session();
for (int i = 0; i < 100; i++)
{
Tensor t = null;
for (int i = 0; i < 100; i++)
{
t = new Tensor(1);
}
var t = new Tensor(1);
}
}
}
@@ -134,12 +119,10 @@ namespace TensorFlowNET.UnitTest
void Core(int tid)
{
//tf.Session created an other graph
using (var sess = tf.Session())
using var sess = tf.Session();
for (int i = 0; i < 100; i++)
{
for (int i = 0; i < 100; i++)
{
var t = new Tensor(new int[] { 1, 2, 3 });
}
var t = new Tensor(new int[] { 1, 2, 3 });
}
}
}
@@ -147,23 +130,23 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void SessionRun()
{
MultiThreadedUnitTestExecuter.Run(8, Core);
MultiThreadedUnitTestExecuter.Run(2, Core);

//the core method
void Core(int tid)
{
tf.compat.v1.disable_eager_execution();
var graph = tf.Graph().as_default();

//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var math = a1 + a2;
using var sess = tf.Session(graph);
for (int i = 0; i < 100; i++)
{
var graph = tf.get_default_graph();
using (var sess = tf.Session(graph))
{
var result = sess.run(math);
Assert.AreEqual(result[0], 5f);
}
var result = sess.run(math);
Assert.AreEqual(result[0], 5f);
}
}
}
@@ -176,17 +159,18 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
using (var sess = tf.Session())
{
Assert.IsNotNull(tf.get_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var math = a1 + a2;
tf.compat.v1.disable_eager_execution();
var graph = tf.Graph().as_default();

var result = sess.run(math);
Assert.AreEqual(result[0], 5f);
}
using var sess = tf.Session(graph);
Assert.IsNotNull(tf.get_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var math = a1 + a2;

var result = sess.run(math);
Assert.AreEqual(result[0], 5f);
}
}

@@ -198,14 +182,12 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
using (var sess = tf.Session())
{
Assert.IsNotNull(tf.get_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var math = a1 + a2;
}
using var sess = tf.Session();
Assert.IsNotNull(tf.get_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var math = a1 + a2;
}
}

@@ -234,6 +216,10 @@ namespace TensorFlowNET.UnitTest
void Core(int tid)
{
Assert.IsNull(tf.peak_default_graph());

tf.compat.v1.disable_eager_execution();
var graph = tf.Graph().as_default();

//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK");
@@ -248,7 +234,6 @@ namespace TensorFlowNET.UnitTest
private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/");

[Ignore]
[TestMethod]
public void TF_GraphOperationByName_FromModel()
{
MultiThreadedUnitTestExecuter.Run(8, Core);


+ 95
- 0
test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs View File

@@ -0,0 +1,95 @@
using System;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Threading.Tasks;
using Tensorflow.NumPy;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace TensorFlowNET.Keras.UnitTest
{
[TestClass]
public class MultiThreads
{
[TestMethod]
public void Test1()
{
//Arrange
string savefile = "mymodel.h5";
var model1 = BuildModel();
model1.save_weights(savefile);
var model2 = BuildModel();

//act
model1.load_weights(savefile);
model2.load_weights(savefile);

}

[TestMethod]
public void Test2()
{
//Arrange
string savefile = "mymodel2.h5";
var model1 = BuildModel();
model1.save_weights(savefile);
model1 = BuildModel(); //recreate model

//act
model1.load_weights(savefile);

}

[TestMethod]
public void Test3Multithreading()
{
//Arrange
string savefile = "mymodel3.h5";
var model = BuildModel();
model.save_weights(savefile);

//Sanity check without multithreading
for (int i = 0; i < 2; i++)
{
Functional clone = BuildModel();
clone.load_weights(savefile);

//Predict something
clone.predict(np.array(new float[,] { { 0, 0 } }));
} //works

//act
ParallelOptions parallelOptions = new ParallelOptions();
parallelOptions.MaxDegreeOfParallelism = 1;
var input = np.array(new float[,] { { 0, 0 } });
Parallel.For(0, 1, parallelOptions, i =>
{
var clone = BuildModel();
clone.load_weights(savefile);
//Predict something
clone.predict(input);
});
}

Functional BuildModel()
{
tf.Context.reset_context();
var inputs = keras.Input(shape: 2);

// 1st dense layer
var DenseLayer = keras.layers.Dense(1, activation: keras.activations.Sigmoid);
var outputs = DenseLayer.Apply(inputs);

// build keras model
Functional model = keras.Model(inputs, outputs, name: Guid.NewGuid().ToString());
// show model summary
model.summary();

// compile keras model into tensorflow's static graph
model.compile(loss: keras.losses.MeanSquaredError(name: Guid.NewGuid().ToString()),
optimizer: keras.optimizers.Adam(name: Guid.NewGuid().ToString()),
metrics: new[] { "accuracy" });
return model;
}
}
}

+ 0
- 3
test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs View File

@@ -16,7 +16,6 @@ namespace TensorFlowNET.UnitTest
/// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary>
public EnforcedSinglethreadingTests()
{
ops.IsSingleThreaded = true;
}

[TestMethod, Ignore("Has to be tested manually.")]
@@ -24,8 +23,6 @@ namespace TensorFlowNET.UnitTest
{
lock (_singlethreadLocker)
{
ops.IsSingleThreaded.Should().BeTrue();

ops.uid(); //increment id by one

//the core method


Loading…
Cancel
Save