From 9d6525ef9f85da8685ab2fbb6c0c091f6596d699 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Tue, 27 Aug 2019 23:59:40 +0300 Subject: [PATCH] Buffer: Revamped and all perf-optted all use-cases. - Fixed all test cases to use using(Buffer) - Fixed all test cases to explicitly specify session --- src/TensorFlowNET.Core/Buffers/Buffer.cs | 102 ++++++++++++++---- .../Framework/op_def_registry.py.cs | 14 +-- .../Graphs/Graph.Control.cs | 4 +- src/TensorFlowNET.Core/Graphs/Graph.Export.cs | 23 ++-- .../Graphs/Graph.Operation.cs | 7 +- .../Operations/Operation.cs | 19 ++-- .../Sessions/SessionOptions.cs | 4 +- src/TensorFlowNET.Core/ops.cs | 4 +- src/TensorFlowNET.Core/tensorflow.cs | 3 +- .../BasicModels/LogisticRegression.cs | 6 +- .../Basics/AssignTests.cs | 26 ++--- .../CApiGradientsTest.cs | 22 ++-- test/TensorFlowNET.UnitTest/GraphTest.cs | 5 - test/TensorFlowNET.UnitTest/NameScopeTest.cs | 37 ++++++- test/TensorFlowNET.UnitTest/OperationsTest.cs | 3 +- test/TensorFlowNET.UnitTest/PythonTest.cs | 2 +- test/TensorFlowNET.UnitTest/SessionTest.cs | 4 +- test/TensorFlowNET.UnitTest/VariableTest.cs | 2 +- test/TensorFlowNET.UnitTest/c_test_util.cs | 17 +-- .../ops_test/CreateOpFromTfOperationTest.cs | 21 ++-- 20 files changed, 216 insertions(+), 109 deletions(-) diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 396fb311..c08d3175 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -15,58 +15,116 @@ ******************************************************************************/ using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using NumSharp.Backends.Unmanaged; +using static Tensorflow.c_api; namespace Tensorflow { + /// + /// Represents a TF_Buffer that can be passed to Tensorflow. + /// public class Buffer : DisposableObject { - private TF_Buffer buffer => Marshal.PtrToStructure(_handle); + private unsafe TF_Buffer buffer + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => *bufferptr; + } + + private unsafe TF_Buffer* bufferptr + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (TF_Buffer*) _handle; + } - public byte[] Data + /// + /// The memory block representing this buffer. + /// + /// The deallocator is set to null. + public UnmanagedMemoryBlock MemoryBlock { - get + get { - var data = new byte[buffer.length]; - if (data.Length > 0) - Marshal.Copy(buffer.data, data, 0, data.Length); - return data; + unsafe + { + EnsureNotDisposed(); + var buff = (TF_Buffer*) _handle; + return new UnmanagedMemoryBlock((byte*) buff->data.ToPointer(), (long) buff->length); + } } } - public int Length => (int)buffer.length; - - public Buffer() + /// + /// The bytes length of this buffer. + /// + public ulong Length { - _handle = c_api.TF_NewBuffer(); + get + { + EnsureNotDisposed(); + return buffer.length; + } } - public Buffer(IntPtr handle) + public Buffer() => _handle = TF_NewBuffer(); + + internal Buffer(IntPtr handle) { + if (handle == IntPtr.Zero) + throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); + _handle = handle; } - public Buffer(byte[] data) - { - var dst = Marshal.AllocHGlobal(data.Length); - Marshal.Copy(data, 0, dst, data.Length); + public Buffer(byte[] data) : this(_toBuffer(data)) + { } - _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); + private static IntPtr _toBuffer(byte[] data) + { + if (data == null) + throw new ArgumentNullException(nameof(data)); - Marshal.FreeHGlobal(dst); + unsafe + { + fixed (byte* src = data) + return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); + } } public static implicit operator IntPtr(Buffer buffer) { + buffer.EnsureNotDisposed(); return buffer._handle; } - public static implicit operator byte[](Buffer buffer) + public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. + + /// + /// Copies this buffer's contents onto a array. + /// + public byte[] ToArray() { - return buffer.Data; + EnsureNotDisposed(); + + unsafe + { + var len = buffer.length; + if (len == 0) + return Array.Empty(); + + byte[] data = new byte[len]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); + + return data; + } } protected override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TF_DeleteBuffer(handle); + { + TF_DeleteBuffer(handle); + } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs index 9f9b4ad7..8a2bc5c3 100644 --- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs +++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System.Collections.Generic; +using System.IO; +using Tensorflow.Util; namespace Tensorflow { @@ -27,12 +29,12 @@ namespace Tensorflow if(_registered_ops == null) { _registered_ops = new Dictionary(); - var handle = c_api.TF_GetAllOpList(); - var buffer = new Buffer(handle); - var op_list = OpList.Parser.ParseFrom(buffer); - - foreach (var op_def in op_list.Op) - _registered_ops[op_def.Name] = op_def; + using (var buffer = new Buffer(c_api.TF_GetAllOpList())) + { + var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + } } return _registered_ops; diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index 4a3ac793..c97e1b6f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using Tensorflow.Operations; @@ -66,8 +67,9 @@ namespace Tensorflow /// within the context should have control dependencies on /// `control_inputs`. /// + [SuppressMessage("ReSharper", "CoVariantArrayConversion")] public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) - => control_dependencies(control_inputs == null ? null : control_inputs.OfType().ToArray()); + => control_dependencies((object[])control_inputs); /// /// Returns a context manager that specifies control dependencies. diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 17828c73..4a7e0ed8 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -14,6 +14,9 @@ limitations under the License. ******************************************************************************/ +using System.IO; +using Tensorflow.Util; + namespace Tensorflow { public partial class Graph @@ -23,21 +26,19 @@ namespace Tensorflow var buffer = new Buffer(); c_api.TF_GraphToGraphDef(_handle, buffer, s); s.Check(true); - // var def = GraphDef.Parser.ParseFrom(buffer); - // buffer.Dispose(); return buffer; } private GraphDef _as_graph_def(bool add_shapes = false) { - var status = new Status(); - var buffer = ToGraphDef(status); - status.Check(true); - status.Dispose(); - - var def = GraphDef.Parser.ParseFrom(buffer); - buffer.Dispose(); + GraphDef def; + using (var status = new Status()) + using (var buffer = ToGraphDef(status)) + { + status.Check(true); + def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } // Strip the experimental library field iff it's empty. // if(def.Library.Function.Count == 0) @@ -45,7 +46,7 @@ namespace Tensorflow return def; } - public GraphDef as_graph_def(bool add_shapes = false) + public GraphDef as_graph_def(bool add_shapes = false) => _as_graph_def(add_shapes); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index ded3ca9c..0e28dd9a 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -18,6 +18,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow @@ -30,7 +31,7 @@ namespace Tensorflow using (var status = new Status()) { c_api.TF_GraphGetOpDef(_handle, type, buffer, status); - return OpDef.Parser.ParseFrom(buffer.Data); + return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } @@ -71,7 +72,7 @@ namespace Tensorflow public ITensorOrOperation[] get_operations() { - return _nodes_by_name.Values.Select(x => x).ToArray(); + return _nodes_by_name.Values.ToArray(); } /// @@ -85,7 +86,7 @@ namespace Tensorflow public ITensorOrOperation _get_operation_by_name_unsafe(string name) { - return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; + return _nodes_by_name.TryGetValue(name, out var val) ? val : null; } public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 059290f4..5fff9ade 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -17,7 +17,9 @@ using Google.Protobuf.Collections; using System; using System.Collections.Generic; +using System.IO; using System.Linq; +using Tensorflow.Util; namespace Tensorflow { @@ -226,9 +228,12 @@ namespace Tensorflow using (var status = new Status()) using (var buf = new Buffer()) { - c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); - status.Check(true); - x = AttrValue.Parser.ParseFrom(buf); + unsafe + { + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); + status.Check(true); + x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); + } } string oneof_value = x.ValueCase.ToString(); @@ -259,7 +264,7 @@ namespace Tensorflow { c_api.TF_OperationToNodeDef(_handle, buffer, s); s.Check(); - return NodeDef.Parser.ParseFrom(buffer); + return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } @@ -299,8 +304,7 @@ namespace Tensorflow /// public TF_Output _tf_output(int output_idx) { - var tf_output = new TF_Output(op, output_idx); - return tf_output; + return new TF_Output(op, output_idx); } /// @@ -308,8 +312,7 @@ namespace Tensorflow /// public TF_Input _tf_input(int input_idx) { - var tf_input = new TF_Input(op, input_idx); - return tf_input; + return new TF_Input(op, input_idx); } } } diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs index ed99b7fe..112543fe 100644 --- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -37,8 +37,8 @@ namespace Tensorflow public void SetConfig(ConfigProto config) { - var bytes = config.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); + var bytes = config.ToByteArray(); //TODO! we can use WriteTo + var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak Marshal.Copy(bytes, 0, proto, bytes.Length); using (var status = new Status()) diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index e485ba6f..1dc8eb56 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -230,8 +230,8 @@ namespace Tensorflow // Add attrs foreach (var attr in node_def.Attr) { - var bytes = attr.Value.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. + var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak Marshal.Copy(bytes, 0, proto, bytes.Length); uint len = (uint)bytes.Length; c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index da873722..ac6c38dc 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -64,8 +64,7 @@ namespace Tensorflow public Session Session() { - defaultSession = new Session(); - return defaultSession; + return new Session(); } public Session Session(Graph graph) diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index 73d40d28..3116e6f4 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples // Display logs per epoch step if ((epoch + 1) % display_step == 0) - print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); + print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); sw.Reset(); } @@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); // Calculate accuracy var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); - float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); - print($"Accuracy: {acc.ToString("F4")}"); + float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); + print($"Accuracy: {acc:F4}"); return acc > 0.9; } diff --git a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs index 15d9b819..6c593929 100644 --- a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs +++ b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics @@ -14,21 +15,22 @@ namespace TensorFlowNET.UnitTest.Basics var expected = new[] { false, true, false, false, true, false, true }; var spike = tf.Variable(false); - - spike.initializer.run(); - foreach (var i in range(1, 2)) + using (var sess = new Session()) { - if (raw_data[i] - raw_data[i - 1] > 5d) - { - var updater = tf.assign(spike, tf.constant(true)); - updater.eval(); - } - else + spike.initializer.run(session: sess); + foreach (var i in range(1, 2)) { - tf.assign(spike, tf.constant(true)).eval(); - } + if (raw_data[i] - raw_data[i - 1] > 5d) + { + var updater = tf.assign(spike, tf.constant(true)); + updater.eval(sess); + } else + { + tf.assign(spike, tf.constant(true)).eval(sess); + } - Assert.AreEqual((bool)spike.eval(), expected[i - 1]); + Assert.AreEqual((bool) spike.eval(), expected[i - 1]); + } } } } diff --git a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs index 33e38870..58609c17 100644 --- a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs @@ -2,6 +2,7 @@ using NumSharp; using System; using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest @@ -45,15 +46,18 @@ namespace TensorFlowNET.UnitTest private bool GetGraphDef(Graph graph, out GraphDef graph_def) { graph_def = null; - var s = new Status(); - var buffer = new Buffer(); - c_api.TF_GraphToGraphDef(graph, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - EXPECT_EQ(TF_OK, TF_GetCode(s)); - if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data); - buffer.Dispose(); - s.Dispose(); - return ret; + using (var s = new Status()) + { + using (var buffer = new Buffer()) + { + c_api.TF_GraphToGraphDef(graph, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)); + if (ret) + graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + return ret; + } + } } private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index f5431e01..94da6d97 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -322,7 +322,6 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(feed2, control_inputs[1]); // Export to a graph def so we can import a graph with control dependencies - graph_def.Dispose(); graph_def = new Buffer(); c_api.TF_GraphToGraphDef(graph, graph_def, s); EXPECT_EQ(TF_Code.TF_OK, s.Code); @@ -346,14 +345,10 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(feed4, control_inputs[1]); c_api.TF_DeleteImportGraphDefOptions(opts); - c_api.TF_DeleteBuffer(graph_def); // Can add nodes to the imported graph without trouble. c_test_util.Add(feed, scalar, graph, s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - - graph.Dispose(); - s.Dispose(); } /// diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs index 4ff50deb..3d763b38 100644 --- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -1,4 +1,5 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; using static Tensorflow.Binding; @@ -42,5 +43,39 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual("", g._name_stack); } + + [TestMethod] + public void NestedNameScope_Using() + { + Graph g = tf.Graph().as_default(); + + using (var name = new ops.NameScope("scope1")) + { + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + using (var name2 = new ops.NameScope("scope2")) + { + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + } + + Assert.AreEqual("scope1", g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + } + + ; + + g.Dispose(); + + Assert.AreEqual("", g._name_stack); + } } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 0caa5259..226a4839 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using NumSharp; using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; using static Tensorflow.Binding; @@ -21,7 +22,7 @@ namespace TensorFlowNET.UnitTest { var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); - var op_list = OpList.Parser.ParseFrom(buffer); + var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); var _registered_ops = new Dictionary(); foreach (var op_def in op_list.Op) diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 701b4b4b..d2ae36d7 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -165,7 +165,7 @@ namespace TensorFlowNET.UnitTest { using (var sess = tf.Session()) { - var ndarray=tensor.eval(); + var ndarray=tensor.eval(sess); if (typeof(T) == typeof(double)) { double x = ndarray; diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 8fd4dc8a..62d7c63d 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -72,8 +72,6 @@ namespace TensorFlowNET.UnitTest // Clean up csession.CloseAndDelete(s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - graph.Dispose(); - s.Dispose(); } [TestMethod] @@ -84,7 +82,7 @@ namespace TensorFlowNET.UnitTest var c = math_ops.matmul(a, b, name: "matmul"); using (var sess = tf.Session()) { - var result = c.eval(); + var result = c.eval(sess); Assert.AreEqual(6, result.Data()[0]); } } diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 211d7b65..4d9d1059 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest { sess.run(init_op); // o some work with the model. - inc_v1.op.run(); + inc_v1.op.run(session: sess); } } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 1b6909e7..627d7c2f 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -1,4 +1,6 @@ -using Tensorflow; +using System.Diagnostics.CodeAnalysis; +using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest @@ -26,12 +28,15 @@ namespace TensorFlowNET.UnitTest return op; } + [SuppressMessage("ReSharper", "RedundantAssignment")] public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) { - var buffer = new Buffer(); - c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); - attr_value = AttrValue.Parser.ParseFrom(buffer); - buffer.Dispose(); + using (var buffer = new Buffer()) + { + c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } + return s.Code == TF_Code.TF_OK; } @@ -42,7 +47,7 @@ namespace TensorFlowNET.UnitTest { c_api.TF_GraphToGraphDef(graph, buffer, s); s.Check(); - return GraphDef.Parser.ParseFrom(buffer); + return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs index cdbd5f14..310ac634 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs @@ -24,16 +24,17 @@ namespace TensorFlowNET.UnitTest.ops_test [TestMethod] public void TestShape() { - var g = tf.Graph().as_default(); - - var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); - var op = g._create_op_from_tf_operation(c_op); - - Assert.AreEqual("myop", op.name); - Assert.AreEqual("Identity", op.type); - Assert.AreEqual(1, len(op.outputs)); - assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); + using (var g = tf.Graph().as_default()) + { + var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); + var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); + var op = g._create_op_from_tf_operation(c_op); + + Assert.AreEqual("myop", op.name); + Assert.AreEqual("Identity", op.type); + Assert.AreEqual(1, len(op.outputs)); + assertItemsEqual(new[] {2, 3}, op.outputs[0].shape); + } } [TestMethod]