Browse Source

added TF_Status, TF_SetAttrValueProto

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
7b24f538fb
7 changed files with 67 additions and 43 deletions
  1. +0
    -1
      .gitignore
  2. +22
    -0
      src/TensorFlowNET.Core/Status.cs
  3. +1
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  4. +27
    -0
      src/TensorFlowNET.Core/Tensorflow/TF_Code.cs
  5. +8
    -2
      src/TensorFlowNET.Core/c_api.cs
  6. +7
    -38
      src/TensorFlowNET.Core/ops.cs
  7. +2
    -1
      test/TensorFlowNET.Examples/HelloWorld.cs

+ 0
- 1
.gitignore View File

@@ -333,5 +333,4 @@ ASALocalRun/
/tensorflowlib/osx/native/libtensorflow.dylib
/tensorflowlib/linux/native/libtensorflow_framework.so
/tensorflowlib/linux/native/libtensorflow.so
/src/TensorFlowNET.Core/libtensorflow.dll
/src/TensorFlowNET.Core/tensorflow.dll

+ 22
- 0
src/TensorFlowNET.Core/Status.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Core
{
public class Status
{
private IntPtr _handle;
public IntPtr Handle => _handle;

public string ErrorMessage => c_api.TF_Message(_handle);

public TF_Code Code => c_api.TF_GetCode(_handle);

public Status()
{
_handle = c_api.TF_NewStatus();
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -22,7 +22,7 @@
</ItemGroup>

<ItemGroup>
<None Update="libtensorflow.dll">
<None Update="tensorflow.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>


+ 27
- 0
src/TensorFlowNET.Core/Tensorflow/TF_Code.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public enum TF_Code
{
TF_OK = 0,
TF_CANCELLED = 1,
TF_UNKNOWN = 2,
TF_INVALID_ARGUMENT = 3,
TF_DEADLINE_EXCEEDED = 4,
TF_NOT_FOUND = 5,
TF_ALREADY_EXISTS = 6,
TF_PERMISSION_DENIED = 7,
TF_UNAUTHENTICATED = 16,
TF_RESOURCE_EXHAUSTED = 8,
TF_FAILED_PRECONDITION = 9,
TF_ABORTED = 10,
TF_OUT_OF_RANGE = 11,
TF_UNIMPLEMENTED = 12,
TF_INTERNAL = 13,
TF_UNAVAILABLE = 14,
TF_DATA_LOSS = 15
}
}

+ 8
- 2
src/TensorFlowNET.Core/c_api.cs View File

@@ -11,7 +11,7 @@ using TF_Status = System.IntPtr;
using TF_Tensor = System.IntPtr;

using TF_DataType = Tensorflow.DataType;
using Tensorflow;
using static TensorFlowNET.Core.Tensorflow;

namespace TensorFlowNET.Core
@@ -23,6 +23,12 @@ namespace TensorFlowNET.Core
[DllImport(TensorFlowLibName)]
public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status);

[DllImport(TensorFlowLibName)]
public static extern unsafe TF_Code TF_GetCode(TF_Status s);

[DllImport(TensorFlowLibName)]
public static extern unsafe string TF_Message(TF_Status s);

[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_NewGraph();

@@ -39,7 +45,7 @@ namespace TensorFlowNET.Core
public static extern unsafe int TF_OperationNumOutputs(TF_Operation oper);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, void* proto, size_t proto_len, TF_Status status);
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, size_t proto_len, TF_Status status);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status);


+ 7
- 38
src/TensorFlowNET.Core/ops.cs View File

@@ -7,6 +7,7 @@ using Tensorflow;
using tf = TensorFlowNET.Core.Tensorflow;
using TF_DataType = Tensorflow.DataType;
using node_def_pb2 = Tensorflow;
using Google.Protobuf;

namespace TensorFlowNET.Core
{
@@ -20,49 +21,17 @@ namespace TensorFlowNET.Core
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs)
{
var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name);
var status = c_api.TF_NewStatus();

// Doesn't work
/*foreach(var attr in node_def.Attr)
{
if (attr.Value.Tensor != null)
{
switch (attr.Value.Tensor.Dtype)
{
case DataType.DtDouble:
var proto = (double*)Marshal.AllocHGlobal(sizeof(double));
*proto = attr.Value.Tensor.DoubleVal[0];
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)sizeof(double), status: status);
break;
}
}
else
{
//c_api.TF_SetAttrValueProto(op_desc, attr.Key, null, proto_len: UIntPtr.Zero, status: status);
}
} */
var status = new Status();

foreach (var attr in node_def.Attr)
{
if (attr.Value.Tensor == null) continue;
switch (attr.Value.Tensor.Dtype)
{
case DataType.DtDouble:
var v = (double*)Marshal.AllocHGlobal(sizeof(double));
*v = attr.Value.Tensor.DoubleVal[0];
var tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
c_api.TF_SetAttrTensor(op_desc, "value", tensor, status);
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble);
break;
case DataType.DtString:
var proto = Marshal.StringToHGlobalAnsi(attr.Value.Tensor.StringVal[0].ToStringUtf8());
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto.ToPointer(), proto_len: (UIntPtr)32, status: status);
break;
}
var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
}

var c_op = c_api.TF_FinishOperation(op_desc, status);
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);

return c_op;
}


+ 2
- 1
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -19,7 +19,8 @@ namespace TensorFlowNET.Examples
The value returned by the constructor represents the output
of the Constant op.*/
var graph = tf.get_default_graph();
var hello = tf.constant("Hello, TensorFlow!");
var hello = tf.constant(4.0);
//var hello = tf.constant("Hello, TensorFlow!");

// Start tf session
// var sess = tf.Session();


Loading…
Cancel
Save