@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
namespace Tensorflow | |||
@@ -14,16 +15,16 @@ namespace Tensorflow | |||
private Status status = new Status(); | |||
public string name { get; } | |||
public string optype { get; } | |||
public string device { get; } | |||
public int NumOutputs { get; } | |||
public TF_DataType OutputType { get; } | |||
public int OutputListLength { get; } | |||
public int NumInputs { get; } | |||
public int NumConsumers { get; } | |||
public int NumControlInputs { get; } | |||
public int NumControlOutputs { get; } | |||
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
private Tensor[] _outputs; | |||
public Tensor[] outputs => _outputs; | |||
@@ -35,17 +36,6 @@ namespace Tensorflow | |||
return; | |||
_handle = handle; | |||
name = c_api.TF_OperationName(_handle); | |||
optype = c_api.TF_OperationOpType(_handle); | |||
device = "";// c_api.TF_OperationDevice(_handle); | |||
NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||
OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||
OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status); | |||
NumInputs = c_api.TF_OperationNumInputs(_handle); | |||
NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||
NumControlInputs = c_api.TF_OperationNumControlInputs(_handle); | |||
NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle); | |||
} | |||
public Operation(Graph g, string opType, string oper_name) | |||
@@ -62,8 +52,8 @@ namespace Tensorflow | |||
Graph = g; | |||
_id_value = Graph._next_id(); | |||
_handle = ops._create_c_op(g, node_def, inputs); | |||
NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||
_outputs = new Tensor[NumOutputs]; | |||
for (int i = 0; i < NumOutputs; i++) | |||
@@ -38,7 +38,7 @@ namespace Tensorflow | |||
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern string TF_OperationDevice(IntPtr oper); | |||
public static extern IntPtr TF_OperationDevice(IntPtr oper); | |||
/// <summary> | |||
/// Sets `output_attr_value` to the binary-serialized AttrValue proto | |||
@@ -50,13 +50,13 @@ namespace Tensorflow | |||
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern string TF_OperationName(IntPtr oper); | |||
public static extern IntPtr TF_OperationName(IntPtr oper); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationNumInputs(IntPtr oper); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern string TF_OperationOpType(IntPtr oper); | |||
public static extern IntPtr TF_OperationOpType(IntPtr oper); | |||
/// <summary> | |||
/// Get the number of control inputs to an operation. | |||
@@ -30,12 +30,12 @@ namespace Tensorflow | |||
// Add inputs | |||
if(inputs != null && inputs.Count > 0) | |||
{ | |||
/*foreach (var op_input in inputs) | |||
foreach (var op_input in inputs) | |||
{ | |||
c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | |||
}*/ | |||
} | |||
c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||
//c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||
} | |||
var status = new Status(); | |||
@@ -15,7 +15,7 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Error message | |||
/// </summary> | |||
public string Message => c_api.TF_Message(_handle); | |||
public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); | |||
/// <summary> | |||
/// Error code | |||
@@ -12,7 +12,7 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <param name="s"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern void TF_DeleteStatus(IntPtr s); | |||
public static extern void TF_DeleteStatus(IntPtr s); | |||
/// <summary> | |||
/// Return the code record in *s. | |||
@@ -20,7 +20,7 @@ namespace Tensorflow | |||
/// <param name="s"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe TF_Code TF_GetCode(IntPtr s); | |||
public static extern TF_Code TF_GetCode(IntPtr s); | |||
/// <summary> | |||
/// Return a pointer to the (null-terminated) error message in *s. | |||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||
/// <param name="s"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe string TF_Message(IntPtr s); | |||
public static extern IntPtr TF_Message(IntPtr s); | |||
/// <summary> | |||
/// Return a new status object. | |||
@@ -12,20 +12,27 @@ namespace Tensorflow | |||
/// The API leans towards simplicity and uniformity instead of convenience | |||
/// since most usage will be by language specific wrappers. | |||
/// | |||
/// The params type mapping between .net and c_api | |||
/// The params type mapping between c_api and .NET | |||
/// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) | |||
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||
/// struct => struct (TF_Output output) => (TF_Output output) | |||
/// struct* => struct (TF_Output* output) => (TF_Output[] output) | |||
/// const char* => string | |||
/// int32_t => int | |||
/// int64_t* => long[] | |||
/// size_t* => unlong[] | |||
/// void* => IntPtr | |||
/// string => IntPtr c_api.StringPiece(IntPtr) | |||
/// </summary> | |||
public static partial class c_api | |||
{ | |||
public const string TensorFlowLibName = "tensorflow"; | |||
public static string StringPiece(IntPtr handle) | |||
{ | |||
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | |||
} | |||
public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator); | |||
[DllImport(TensorFlowLibName)] | |||
@@ -37,7 +37,7 @@ namespace Tensorflow | |||
context.default_execution_mode = Context.EAGER_MODE; | |||
} | |||
public static string VERSION => Marshal.PtrToStringAnsi(c_api.TF_Version()); | |||
public static string VERSION => c_api.StringPiece(c_api.TF_Version()); | |||
public static Graph get_default_graph() | |||
{ | |||
@@ -39,17 +39,14 @@ namespace TensorFlowNET.UnitTest | |||
// Test not found errors in TF_Operation*() query functions. | |||
Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
//Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||
//Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message); | |||
Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||
Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
// Make a constant oper with the scalar "3". | |||
var three = c_test_util.ScalarConst(3, graph, s); | |||
// Add oper. | |||
var add = c_test_util.Add(feed, three, graph, s); | |||
NodeDef node_def = null; | |||
c_test_util.GetNodeDef(feed, ref node_def); | |||
} | |||
} | |||
} |
@@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest | |||
public void addInConstant() | |||
{ | |||
var a = tf.constant(4.0f); | |||
var b = tf.placeholder(tf.float32); | |||
var b = tf.constant(5.0f); | |||
var c = tf.add(a, b); | |||
using (var sess = tf.Session()) | |||
@@ -23,11 +23,13 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||
c_api.TF_AddInputList(desc, new TF_Output[] | |||
var inputs = new TF_Output[] | |||
{ | |||
new TF_Output(l, 0), | |||
new TF_Output(r, 0), | |||
}, 2); | |||
}; | |||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||
op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||