Browse Source

Enable Multi-threading tests.

tags/v0.20
Oceania2018 5 years ago
parent
commit
e6876306e9
2 changed files with 25 additions and 31 deletions
  1. +15
    -8
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  2. +10
    -23
      test/TensorFlowNET.UnitTest/MultithreadingTests.cs

+ 15
- 8
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using Tensorflow.Util; using Tensorflow.Util;
@@ -22,18 +23,24 @@ namespace Tensorflow
{ {
public class op_def_registry public class op_def_registry
{ {
static Dictionary<string, OpDef> _registered_ops;
static Dictionary<string, OpDef> _registered_ops = new Dictionary<string, OpDef>();


public static Dictionary<string, OpDef> get_registered_ops() public static Dictionary<string, OpDef> get_registered_ops()
{ {
if(_registered_ops == null)
if(_registered_ops.Count == 0)
{ {
_registered_ops = new Dictionary<string, OpDef>();
using var buffer = new Buffer(c_api.TF_GetAllOpList());
using var stream = buffer.DangerousMemoryBlock.Stream();
var op_list = OpList.Parser.ParseFrom(stream);
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
lock (_registered_ops)
{
// double validation to avoid multi-thread executing
if (_registered_ops.Count > 0)
return _registered_ops;

using var buffer = new Buffer(c_api.TF_GetAllOpList());
using var stream = buffer.DangerousMemoryBlock.Stream();
var op_list = OpList.Parser.ParseFrom(stream);
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
}
} }


return _registered_ops; return _registered_ops;


+ 10
- 23
test/TensorFlowNET.UnitTest/MultithreadingTests.cs View File

@@ -7,13 +7,13 @@ using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp; using NumSharp;
using Tensorflow; using Tensorflow;
using Tensorflow.Util;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
[TestClass] [TestClass]
public class MultithreadingTests
public class MultithreadingTests : GraphModeTestBase
{ {
[TestMethod] [TestMethod]
public void SessionCreation() public void SessionCreation()
@@ -184,7 +184,6 @@ namespace TensorFlowNET.UnitTest
} }
} }


[Ignore]
[TestMethod] [TestMethod]
public void SessionRun() public void SessionRun()
{ {
@@ -208,7 +207,6 @@ namespace TensorFlowNET.UnitTest
} }
} }


[Ignore]
[TestMethod] [TestMethod]
public void SessionRun_InsideSession() public void SessionRun_InsideSession()
{ {
@@ -231,7 +229,6 @@ namespace TensorFlowNET.UnitTest
} }
} }


[Ignore]
[TestMethod] [TestMethod]
public void SessionRun_Initialization() public void SessionRun_Initialization()
{ {
@@ -251,7 +248,6 @@ namespace TensorFlowNET.UnitTest
} }
} }


[Ignore]
[TestMethod] [TestMethod]
public void SessionRun_Initialization_OutsideSession() public void SessionRun_Initialization_OutsideSession()
{ {
@@ -268,7 +264,6 @@ namespace TensorFlowNET.UnitTest
} }
} }


[Ignore]
[TestMethod] [TestMethod]
public void TF_GraphOperationByName() public void TF_GraphOperationByName()
{ {
@@ -309,23 +304,15 @@ namespace TensorFlowNET.UnitTest
var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray();
var outp = sess.graph.OperationByName("softmax_tensor").output; var outp = sess.graph.OperationByName("softmax_tensor").output;


for (var i = 0; i < 100; i++)
for (var i = 0; i < 8; i++)
{ {
{
var data = new float[96];
FeedItem[] feeds = new FeedItem[2];

for (int f = 0; f < 2; f++)
feeds[f] = new FeedItem(inp[f], new NDArray(data));

try
{
sess.run(outp, feeds);
} catch (Exception ex)
{
Console.WriteLine(ex);
}
}
var data = new float[96];
FeedItem[] feeds = new FeedItem[2];

for (int f = 0; f < 2; f++)
feeds[f] = new FeedItem(inp[f], new NDArray(data));

sess.run(outp, feeds);
} }
} }
} }


Loading…
Cancel
Save