|
|
@@ -7,6 +7,7 @@ using Tensorflow; |
|
|
|
using tf = TensorFlowNET.Core.Tensorflow; |
|
|
|
using TF_DataType = Tensorflow.DataType; |
|
|
|
using node_def_pb2 = Tensorflow; |
|
|
|
using Google.Protobuf; |
|
|
|
|
|
|
|
namespace TensorFlowNET.Core |
|
|
|
{ |
|
|
@@ -20,49 +21,17 @@ namespace TensorFlowNET.Core |
|
|
|
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) |
|
|
|
{ |
|
|
|
var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name); |
|
|
|
var status = c_api.TF_NewStatus(); |
|
|
|
|
|
|
|
// Doesn't work |
|
|
|
/*foreach(var attr in node_def.Attr) |
|
|
|
{ |
|
|
|
if (attr.Value.Tensor != null) |
|
|
|
{ |
|
|
|
switch (attr.Value.Tensor.Dtype) |
|
|
|
{ |
|
|
|
case DataType.DtDouble: |
|
|
|
var proto = (double*)Marshal.AllocHGlobal(sizeof(double)); |
|
|
|
*proto = attr.Value.Tensor.DoubleVal[0]; |
|
|
|
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)sizeof(double), status: status); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
//c_api.TF_SetAttrValueProto(op_desc, attr.Key, null, proto_len: UIntPtr.Zero, status: status); |
|
|
|
} |
|
|
|
} */ |
|
|
|
var status = new Status(); |
|
|
|
|
|
|
|
foreach (var attr in node_def.Attr) |
|
|
|
{ |
|
|
|
if (attr.Value.Tensor == null) continue; |
|
|
|
switch (attr.Value.Tensor.Dtype) |
|
|
|
{ |
|
|
|
case DataType.DtDouble: |
|
|
|
var v = (double*)Marshal.AllocHGlobal(sizeof(double)); |
|
|
|
*v = attr.Value.Tensor.DoubleVal[0]; |
|
|
|
var tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); |
|
|
|
c_api.TF_SetAttrTensor(op_desc, "value", tensor, status); |
|
|
|
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble); |
|
|
|
break; |
|
|
|
case DataType.DtString: |
|
|
|
|
|
|
|
var proto = Marshal.StringToHGlobalAnsi(attr.Value.Tensor.StringVal[0].ToStringUtf8()); |
|
|
|
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto.ToPointer(), proto_len: (UIntPtr)32, status: status); |
|
|
|
break; |
|
|
|
} |
|
|
|
var bytes = attr.Value.ToByteArray(); |
|
|
|
var proto = Marshal.AllocHGlobal(bytes.Length); |
|
|
|
Marshal.Copy(bytes, 0, proto, bytes.Length); |
|
|
|
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); |
|
|
|
} |
|
|
|
|
|
|
|
var c_op = c_api.TF_FinishOperation(op_desc, status); |
|
|
|
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); |
|
|
|
|
|
|
|
return c_op; |
|
|
|
} |
|
|
|