Browse Source

fixed c_api.TF_OperationGetAttrMetadata #86

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
dba38209cc
7 changed files with 55 additions and 8 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Attributes/c_api.ops.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +15
    -0
      src/TensorFlowNET.Core/Operations/OperationDescription.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs
  5. +10
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  6. +7
    -0
      test/TensorFlowNET.Examples/LinearRegression.cs
  7. +15
    -5
      test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs

+ 2
- 2
src/TensorFlowNET.Core/Attributes/c_api.ops.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status);
public static extern TF_AttrMetadata TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status);

/// <summary>
/// Fills in `value` with the value of the attribute `attr_name`. `value` must
@@ -71,7 +71,7 @@ namespace Tensorflow
/// <param name="value">const void*</param>
/// <param name="length">size_t</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length);
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length);

/// <summary>
///


+ 5
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -145,6 +145,11 @@ namespace Tensorflow
return ret;
}

public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
{
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
}

public NodeDef GetNodeDef()
{
using (var s = new Status())


+ 15
- 0
src/TensorFlowNET.Core/Operations/OperationDescription.cs View File

@@ -8,6 +8,11 @@ namespace Tensorflow
{
private IntPtr _handle;

public OperationDescription(Graph graph, string opType, string opName)
{
_handle = c_api.TF_NewOperation(graph, opType, opName);
}

public OperationDescription(IntPtr handle)
{
_handle = handle;
@@ -18,6 +23,16 @@ namespace Tensorflow
c_api.TF_AddInputList(_handle, inputs, inputs.Length);
}

public void SetAttrType(string attr_name, TF_DataType value)
{
c_api.TF_SetAttrType(_handle, attr_name, value);
}

public void SetAttrShape(string attr_name, long[] dims)
{
c_api.TF_SetAttrShape(_handle, attr_name, dims, dims.Length);
}

public Operation FinishOperation(Status status)
{
return c_api.TF_FinishOperation(_handle, status);


+ 1
- 1
src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs View File

@@ -6,7 +6,7 @@ namespace Tensorflow
{
public struct TF_AttrMetadata
{
public char is_list;
public byte is_list;
public long list_size;
public TF_AttrType type;
public long total_size;


+ 10
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -46,6 +46,16 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status);

/// <summary>
/// Operation will only be added to *graph when TF_FinishOperation() is
/// called (assuming TF_FinishOperation() does not return an error).
/// *graph must not be deleted until after TF_FinishOperation() is
/// called.
/// </summary>
/// <param name="graph">TF_Graph*</param>
/// <param name="opType">const char*</param>
/// <param name="oper_name">const char*</param>
/// <returns>TF_OperationDescription*</returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);



+ 7
- 0
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -42,6 +42,13 @@ namespace TensorFlowNET.Examples
// Mean squared error
var sub = pred - Y;
var pow = tf.pow(sub, 2);







var reduce = tf.reduce_sum(pow);
var cost = reduce / (2d * n_samples);



+ 15
- 5
test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs View File

@@ -65,12 +65,11 @@ namespace TensorFlowNET.UnitTest
public void String()
{
var desc = init("string");
var handle = Marshal.StringToHGlobalAnsi("bunny");
c_api.TF_SetAttrString(desc, "v", handle, 5);
c_api.TF_SetAttrString(desc, "v", "bunny", 5);

//var oper = c_api.TF_FinishOperation(desc, s_);
//ASSERT_EQ(TF_Code.TF_OK, s_.Code);
//EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5);
var oper = c_api.TF_FinishOperation(desc, s_);
ASSERT_EQ(TF_Code.TF_OK, s_.Code);
EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5);
//var value = new char[5];

//c_api.TF_OperationGetAttrString(oper, "v", value, 5, s_);
@@ -78,6 +77,17 @@ namespace TensorFlowNET.UnitTest
//EXPECT_EQ("bunny", value, 5));
}

[TestMethod]
public void GetAttributesTest()
{
var desc = graph_.NewOperation("Placeholder", "node");
desc.SetAttrType("dtype", TF_DataType.TF_FLOAT);
long[] ref_shape = new long[3] { 1, 2, 3 };
desc.SetAttrShape("shape", ref_shape);
var oper = desc.FinishOperation(s_);
var metadata = oper.GetAttributeMetadata("shape", s_);
}

public void Dispose()
{
graph_.Dispose();


Loading…
Cancel
Save