@@ -18,7 +18,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status); | |||||
public static extern TF_AttrMetadata TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status); | |||||
/// <summary> | /// <summary> | ||||
/// Fills in `value` with the value of the attribute `attr_name`. `value` must | /// Fills in `value` with the value of the attribute `attr_name`. `value` must | ||||
@@ -71,7 +71,7 @@ namespace Tensorflow | |||||
/// <param name="value">const void*</param> | /// <param name="value">const void*</param> | ||||
/// <param name="length">size_t</param> | /// <param name="length">size_t</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length); | |||||
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -145,6 +145,11 @@ namespace Tensorflow | |||||
return ret; | return ret; | ||||
} | } | ||||
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||||
{ | |||||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||||
} | |||||
public NodeDef GetNodeDef() | public NodeDef GetNodeDef() | ||||
{ | { | ||||
using (var s = new Status()) | using (var s = new Status()) | ||||
@@ -8,6 +8,11 @@ namespace Tensorflow | |||||
{ | { | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
public OperationDescription(Graph graph, string opType, string opName) | |||||
{ | |||||
_handle = c_api.TF_NewOperation(graph, opType, opName); | |||||
} | |||||
public OperationDescription(IntPtr handle) | public OperationDescription(IntPtr handle) | ||||
{ | { | ||||
_handle = handle; | _handle = handle; | ||||
@@ -18,6 +23,16 @@ namespace Tensorflow | |||||
c_api.TF_AddInputList(_handle, inputs, inputs.Length); | c_api.TF_AddInputList(_handle, inputs, inputs.Length); | ||||
} | } | ||||
public void SetAttrType(string attr_name, TF_DataType value) | |||||
{ | |||||
c_api.TF_SetAttrType(_handle, attr_name, value); | |||||
} | |||||
public void SetAttrShape(string attr_name, long[] dims) | |||||
{ | |||||
c_api.TF_SetAttrShape(_handle, attr_name, dims, dims.Length); | |||||
} | |||||
public Operation FinishOperation(Status status) | public Operation FinishOperation(Status status) | ||||
{ | { | ||||
return c_api.TF_FinishOperation(_handle, status); | return c_api.TF_FinishOperation(_handle, status); | ||||
@@ -6,7 +6,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public struct TF_AttrMetadata | public struct TF_AttrMetadata | ||||
{ | { | ||||
public char is_list; | |||||
public byte is_list; | |||||
public long list_size; | public long list_size; | ||||
public TF_AttrType type; | public TF_AttrType type; | ||||
public long total_size; | public long total_size; | ||||
@@ -46,6 +46,16 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | ||||
/// <summary> | |||||
/// Operation will only be added to *graph when TF_FinishOperation() is | |||||
/// called (assuming TF_FinishOperation() does not return an error). | |||||
/// *graph must not be deleted until after TF_FinishOperation() is | |||||
/// called. | |||||
/// </summary> | |||||
/// <param name="graph">TF_Graph*</param> | |||||
/// <param name="opType">const char*</param> | |||||
/// <param name="oper_name">const char*</param> | |||||
/// <returns>TF_OperationDescription*</returns> | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | ||||
@@ -42,6 +42,13 @@ namespace TensorFlowNET.Examples | |||||
// Mean squared error | // Mean squared error | ||||
var sub = pred - Y; | var sub = pred - Y; | ||||
var pow = tf.pow(sub, 2); | var pow = tf.pow(sub, 2); | ||||
var reduce = tf.reduce_sum(pow); | var reduce = tf.reduce_sum(pow); | ||||
var cost = reduce / (2d * n_samples); | var cost = reduce / (2d * n_samples); | ||||
@@ -65,12 +65,11 @@ namespace TensorFlowNET.UnitTest | |||||
public void String() | public void String() | ||||
{ | { | ||||
var desc = init("string"); | var desc = init("string"); | ||||
var handle = Marshal.StringToHGlobalAnsi("bunny"); | |||||
c_api.TF_SetAttrString(desc, "v", handle, 5); | |||||
c_api.TF_SetAttrString(desc, "v", "bunny", 5); | |||||
//var oper = c_api.TF_FinishOperation(desc, s_); | |||||
//ASSERT_EQ(TF_Code.TF_OK, s_.Code); | |||||
//EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); | |||||
var oper = c_api.TF_FinishOperation(desc, s_); | |||||
ASSERT_EQ(TF_Code.TF_OK, s_.Code); | |||||
EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); | |||||
//var value = new char[5]; | //var value = new char[5]; | ||||
//c_api.TF_OperationGetAttrString(oper, "v", value, 5, s_); | //c_api.TF_OperationGetAttrString(oper, "v", value, 5, s_); | ||||
@@ -78,6 +77,17 @@ namespace TensorFlowNET.UnitTest | |||||
//EXPECT_EQ("bunny", value, 5)); | //EXPECT_EQ("bunny", value, 5)); | ||||
} | } | ||||
[TestMethod] | |||||
public void GetAttributesTest() | |||||
{ | |||||
var desc = graph_.NewOperation("Placeholder", "node"); | |||||
desc.SetAttrType("dtype", TF_DataType.TF_FLOAT); | |||||
long[] ref_shape = new long[3] { 1, 2, 3 }; | |||||
desc.SetAttrShape("shape", ref_shape); | |||||
var oper = desc.FinishOperation(s_); | |||||
var metadata = oper.GetAttributeMetadata("shape", s_); | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
graph_.Dispose(); | graph_.Dispose(); | ||||