From 1f6ea31a22f711d704dd527ce90a4dfda4636cc7 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 29 Dec 2018 10:04:32 -0600 Subject: [PATCH] c_test_util.GetAttrValue --- src/TensorFlowNET.Core/Buffers/Buffer.cs | 36 ++++++++++++++----- .../Buffers/c_api.buffer.cs | 3 ++ src/TensorFlowNET.Core/Functions/Function.cs | 11 ++++++ .../Functions/c_api.function.cs | 22 ++++++++++++ .../Operations/c_api.ops.cs | 9 +++++ test/TensorFlowNET.UnitTest/GraphTest.cs | 4 +-- test/TensorFlowNET.UnitTest/c_test_util.cs | 8 +++-- 7 files changed, 80 insertions(+), 13 deletions(-) create mode 100644 src/TensorFlowNET.Core/Functions/Function.cs create mode 100644 src/TensorFlowNET.Core/Functions/c_api.function.cs diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 7e387522..112afc9c 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -5,23 +5,43 @@ using System.Text; namespace Tensorflow { - public class Buffer + public class Buffer : IDisposable { private IntPtr _handle; - private TF_Buffer buffer; + private TF_Buffer buffer => Marshal.PtrToStructure(_handle); - public byte[] Data; + public byte[] Data + { + get + { + var data = new byte[buffer.length]; + if (buffer.length > 0) + Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + return data; + } + } public int Length => (int)buffer.length; - public unsafe Buffer(IntPtr handle) + public Buffer() + { + _handle = c_api.TF_NewBuffer(); + } + + public Buffer(IntPtr handle) { _handle = handle; - buffer = Marshal.PtrToStructure(_handle); - Data = new byte[buffer.length]; - if (buffer.length > 0) - Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); + } + + public static implicit operator IntPtr(Buffer buffer) + { + return buffer._handle; + } + + public void Dispose() + { + c_api.TF_DeleteBuffer(_handle); } } } diff --git a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs index 86857392..9adfd411 100644 --- a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs @@ -7,6 +7,9 @@ namespace Tensorflow { public static partial class c_api { + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteBuffer(IntPtr buffer); + /// /// Useful for passing *out* a protobuf. /// diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs new file mode 100644 index 00000000..93a0590f --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class Function + { + + } +} diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs new file mode 100644 index 00000000..32c020a6 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public static partial class c_api + { + /// + /// Write out a serialized representation of `func` (as a FunctionDef protocol + /// message) to `output_func_def` (allocated by TF_NewBuffer()). + /// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() + /// is called. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, IntPtr status); + } +} diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 39b82b13..b97e620d 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -31,6 +31,15 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern string TF_OperationDevice(IntPtr oper); + /// + /// Sets `output_attr_value` to the binary-serialized AttrValue proto + /// representation of the value of the `attr_name` attr of `oper`. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); + [DllImport(TensorFlowLibName)] public static extern string TF_OperationName(IntPtr oper); diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 1cb1c7e7..30c08425 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -28,8 +28,8 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(0, feed.NumControlInputs); Assert.AreEqual(0, feed.NumControlOutputs); - var attr_value = new AttrValue(); - c_test_util.GetAttrValue(feed, "dtype", attr_value, s); + AttrValue attr_value = null; + c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); } } } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 015081d8..744afe4e 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -9,10 +9,12 @@ namespace TensorFlowNET.UnitTest { public static class c_test_util { - public static bool GetAttrValue(Operation oper, string attr_name, AttrValue attr_value, Status s) + public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) { - var buffer = c_api.TF_NewBuffer(); - + var buffer = new Buffer(); + c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + attr_value = AttrValue.Parser.ParseFrom(buffer.Data); + buffer.Dispose(); return s.Code == TF_Code.TF_OK; }