Browse Source

Enable Multi-threading tests.

tags/v0.20
Oceania2018 5 years ago
parent
commit
e6876306e9
2 changed files with 25 additions and 31 deletions
  1. +15
    -8
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  2. +10
    -23
      test/TensorFlowNET.UnitTest/MultithreadingTests.cs

+ 15
- 8
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.IO;
using Tensorflow.Util;
@@ -22,18 +23,24 @@ namespace Tensorflow
{
public class op_def_registry
{
static Dictionary<string, OpDef> _registered_ops;
static Dictionary<string, OpDef> _registered_ops = new Dictionary<string, OpDef>();

public static Dictionary<string, OpDef> get_registered_ops()
{
if(_registered_ops == null)
if(_registered_ops.Count == 0)
{
_registered_ops = new Dictionary<string, OpDef>();
using var buffer = new Buffer(c_api.TF_GetAllOpList());
using var stream = buffer.DangerousMemoryBlock.Stream();
var op_list = OpList.Parser.ParseFrom(stream);
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
lock (_registered_ops)
{
// double validation to avoid multi-thread executing
if (_registered_ops.Count > 0)
return _registered_ops;

using var buffer = new Buffer(c_api.TF_GetAllOpList());
using var stream = buffer.DangerousMemoryBlock.Stream();
var op_list = OpList.Parser.ParseFrom(stream);
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
}
}

return _registered_ops;


+ 10
- 23
test/TensorFlowNET.UnitTest/MultithreadingTests.cs View File

@@ -7,13 +7,13 @@ using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using Tensorflow.Util;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class MultithreadingTests
public class MultithreadingTests : GraphModeTestBase
{
[TestMethod]
public void SessionCreation()
@@ -184,7 +184,6 @@ namespace TensorFlowNET.UnitTest
}
}

[Ignore]
[TestMethod]
public void SessionRun()
{
@@ -208,7 +207,6 @@ namespace TensorFlowNET.UnitTest
}
}

[Ignore]
[TestMethod]
public void SessionRun_InsideSession()
{
@@ -231,7 +229,6 @@ namespace TensorFlowNET.UnitTest
}
}

[Ignore]
[TestMethod]
public void SessionRun_Initialization()
{
@@ -251,7 +248,6 @@ namespace TensorFlowNET.UnitTest
}
}

[Ignore]
[TestMethod]
public void SessionRun_Initialization_OutsideSession()
{
@@ -268,7 +264,6 @@ namespace TensorFlowNET.UnitTest
}
}

[Ignore]
[TestMethod]
public void TF_GraphOperationByName()
{
@@ -309,23 +304,15 @@ namespace TensorFlowNET.UnitTest
var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray();
var outp = sess.graph.OperationByName("softmax_tensor").output;

for (var i = 0; i < 100; i++)
for (var i = 0; i < 8; i++)
{
{
var data = new float[96];
FeedItem[] feeds = new FeedItem[2];

for (int f = 0; f < 2; f++)
feeds[f] = new FeedItem(inp[f], new NDArray(data));

try
{
sess.run(outp, feeds);
} catch (Exception ex)
{
Console.WriteLine(ex);
}
}
var data = new float[96];
FeedItem[] feeds = new FeedItem[2];

for (int f = 0; f < 2; f++)
feeds[f] = new FeedItem(inp[f], new NDArray(data));

sess.run(outp, feeds);
}
}
}


Loading…
Cancel
Save