Browse Source

c_test_util.GetAttrValue

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
1f6ea31a22
7 changed files with 80 additions and 13 deletions
  1. +28
    -8
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
  3. +11
    -0
      src/TensorFlowNET.Core/Functions/Function.cs
  4. +22
    -0
      src/TensorFlowNET.Core/Functions/c_api.function.cs
  5. +9
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  6. +2
    -2
      test/TensorFlowNET.UnitTest/GraphTest.cs
  7. +5
    -3
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 28
- 8
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -5,23 +5,43 @@ using System.Text;

namespace Tensorflow
{
public class Buffer
public class Buffer : IDisposable
{
private IntPtr _handle;

private TF_Buffer buffer;
private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle);

public byte[] Data;
public byte[] Data
{
get
{
var data = new byte[buffer.length];
if (buffer.length > 0)
Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
return data;
}
}

public int Length => (int)buffer.length;

public unsafe Buffer(IntPtr handle)
public Buffer()
{
_handle = c_api.TF_NewBuffer();
}

public Buffer(IntPtr handle)
{
_handle = handle;
buffer = Marshal.PtrToStructure<TF_Buffer>(_handle);
Data = new byte[buffer.length];
if (buffer.length > 0)
Marshal.Copy(buffer.data, Data, 0, (int)buffer.length);
}

public static implicit operator IntPtr(Buffer buffer)
{
return buffer._handle;
}

public void Dispose()
{
c_api.TF_DeleteBuffer(_handle);
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Buffers/c_api.buffer.cs View File

@@ -7,6 +7,9 @@ namespace Tensorflow
{
public static partial class c_api
{
[DllImport(TensorFlowLibName)]
public static extern void TF_DeleteBuffer(IntPtr buffer);

/// <summary>
/// Useful for passing *out* a protobuf.
/// </summary>


+ 11
- 0
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class Function
{

}
}

+ 22
- 0
src/TensorFlowNET.Core/Functions/c_api.function.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public static partial class c_api
{
/// <summary>
/// Write out a serialized representation of `func` (as a FunctionDef protocol
/// message) to `output_func_def` (allocated by TF_NewBuffer()).
/// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
/// is called.
/// </summary>
/// <param name="func"></param>
/// <param name="output_func_def"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, IntPtr status);
}
}

+ 9
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -31,6 +31,15 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern string TF_OperationDevice(IntPtr oper);

/// <summary>
/// Sets `output_attr_value` to the binary-serialized AttrValue proto
/// representation of the value of the `attr_name` attr of `oper`.
/// </summary>
/// <param name="oper"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern string TF_OperationName(IntPtr oper);



+ 2
- 2
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -28,8 +28,8 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(0, feed.NumControlInputs);
Assert.AreEqual(0, feed.NumControlOutputs);

var attr_value = new AttrValue();
c_test_util.GetAttrValue(feed, "dtype", attr_value, s);
AttrValue attr_value = null;
c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s);
}
}
}

+ 5
- 3
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -9,10 +9,12 @@ namespace TensorFlowNET.UnitTest
{
public static class c_test_util
{
public static bool GetAttrValue(Operation oper, string attr_name, AttrValue attr_value, Status s)
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s)
{
var buffer = c_api.TF_NewBuffer();

var buffer = new Buffer();
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
attr_value = AttrValue.Parser.ParseFrom(buffer.Data);
buffer.Dispose();
return s.Code == TF_Code.TF_OK;
}



Loading…
Cancel
Save