diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs index e82805e8..dfc654a4 100644 --- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs +++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using System.Collections.Generic; using System.IO; using Tensorflow.Util; @@ -22,18 +23,24 @@ namespace Tensorflow { public class op_def_registry { - static Dictionary _registered_ops; + static Dictionary _registered_ops = new Dictionary(); public static Dictionary get_registered_ops() { - if(_registered_ops == null) + if(_registered_ops.Count == 0) { - _registered_ops = new Dictionary(); - using var buffer = new Buffer(c_api.TF_GetAllOpList()); - using var stream = buffer.DangerousMemoryBlock.Stream(); - var op_list = OpList.Parser.ParseFrom(stream); - foreach (var op_def in op_list.Op) - _registered_ops[op_def.Name] = op_def; + lock (_registered_ops) + { + // double validation to avoid multi-thread executing + if (_registered_ops.Count > 0) + return _registered_ops; + + using var buffer = new Buffer(c_api.TF_GetAllOpList()); + using var stream = buffer.DangerousMemoryBlock.Stream(); + var op_list = OpList.Parser.ParseFrom(stream); + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + } } return _registered_ops; diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs index 543248bf..edf30177 100644 --- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -7,13 +7,13 @@ using FluentAssertions; using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; using Tensorflow; -using Tensorflow.Util; +using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest { [TestClass] - public class MultithreadingTests + public class MultithreadingTests : GraphModeTestBase { [TestMethod] public void SessionCreation() @@ -184,7 +184,6 @@ namespace TensorFlowNET.UnitTest } } - [Ignore] [TestMethod] public void SessionRun() { @@ -208,7 +207,6 @@ namespace TensorFlowNET.UnitTest } } - [Ignore] [TestMethod] public void SessionRun_InsideSession() { @@ -231,7 +229,6 @@ namespace TensorFlowNET.UnitTest } } - [Ignore] [TestMethod] public void SessionRun_Initialization() { @@ -251,7 +248,6 @@ namespace TensorFlowNET.UnitTest } } - [Ignore] [TestMethod] public void SessionRun_Initialization_OutsideSession() { @@ -268,7 +264,6 @@ namespace TensorFlowNET.UnitTest } } - [Ignore] [TestMethod] public void TF_GraphOperationByName() { @@ -309,23 +304,15 @@ namespace TensorFlowNET.UnitTest var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); var outp = sess.graph.OperationByName("softmax_tensor").output; - for (var i = 0; i < 100; i++) + for (var i = 0; i < 8; i++) { - { - var data = new float[96]; - FeedItem[] feeds = new FeedItem[2]; - - for (int f = 0; f < 2; f++) - feeds[f] = new FeedItem(inp[f], new NDArray(data)); - - try - { - sess.run(outp, feeds); - } catch (Exception ex) - { - Console.WriteLine(ex); - } - } + var data = new float[96]; + FeedItem[] feeds = new FeedItem[2]; + + for (int f = 0; f < 2; f++) + feeds[f] = new FeedItem(inp[f], new NDArray(data)); + + sess.run(outp, feeds); } } }