Browse Source

add c_api.StringPiece to avoid crash for unmanaged memory.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
4fa14feeca
10 changed files with 38 additions and 42 deletions
  1. +12
    -22
      src/TensorFlowNET.Core/Operations/Operation.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Operations/ops.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  5. +3
    -3
      src/TensorFlowNET.Core/Status/c_api.status.cs
  6. +8
    -1
      src/TensorFlowNET.Core/c_api.cs
  7. +1
    -1
      src/TensorFlowNET.Core/tf.cs
  8. +2
    -5
      test/TensorFlowNET.UnitTest/GraphTest.cs
  9. +1
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  10. +4
    -2
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 12
- 22
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
@@ -14,16 +15,16 @@ namespace Tensorflow

private Status status = new Status();

public string name { get; }
public string optype { get; }
public string device { get; }
public int NumOutputs { get; }
public TF_DataType OutputType { get; }
public int OutputListLength { get; }
public int NumInputs { get; }
public int NumConsumers { get; }
public int NumControlInputs { get; }
public int NumControlOutputs { get; }
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));
public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0));
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0));
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

private Tensor[] _outputs;
public Tensor[] outputs => _outputs;
@@ -35,17 +36,6 @@ namespace Tensorflow
return;

_handle = handle;

name = c_api.TF_OperationName(_handle);
optype = c_api.TF_OperationOpType(_handle);
device = "";// c_api.TF_OperationDevice(_handle);
NumOutputs = c_api.TF_OperationNumOutputs(_handle);
OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0));
OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status);
NumInputs = c_api.TF_OperationNumInputs(_handle);
NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0));
NumControlInputs = c_api.TF_OperationNumControlInputs(_handle);
NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle);
}

public Operation(Graph g, string opType, string oper_name)
@@ -62,8 +52,8 @@ namespace Tensorflow
Graph = g;

_id_value = Graph._next_id();

_handle = ops._create_c_op(g, node_def, inputs);
NumOutputs = c_api.TF_OperationNumOutputs(_handle);

_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)


+ 3
- 3
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);

[DllImport(TensorFlowLibName)]
public static extern string TF_OperationDevice(IntPtr oper);
public static extern IntPtr TF_OperationDevice(IntPtr oper);

/// <summary>
/// Sets `output_attr_value` to the binary-serialized AttrValue proto
@@ -50,13 +50,13 @@ namespace Tensorflow
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern string TF_OperationName(IntPtr oper);
public static extern IntPtr TF_OperationName(IntPtr oper);

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationNumInputs(IntPtr oper);

[DllImport(TensorFlowLibName)]
public static extern string TF_OperationOpType(IntPtr oper);
public static extern IntPtr TF_OperationOpType(IntPtr oper);

/// <summary>
/// Get the number of control inputs to an operation.


+ 3
- 3
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -30,12 +30,12 @@ namespace Tensorflow
// Add inputs
if(inputs != null && inputs.Count > 0)
{
/*foreach (var op_input in inputs)
foreach (var op_input in inputs)
{
c_api.TF_AddInput(op_desc, op_input._as_tf_output());
}*/
}

c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
//c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
}

var status = new Status();


+ 1
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
/// <summary>
/// Error message
/// </summary>
public string Message => c_api.TF_Message(_handle);
public string Message => c_api.StringPiece(c_api.TF_Message(_handle));

/// <summary>
/// Error code


+ 3
- 3
src/TensorFlowNET.Core/Status/c_api.status.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// </summary>
/// <param name="s"></param>
[DllImport(TensorFlowLibName)]
public static unsafe extern void TF_DeleteStatus(IntPtr s);
public static extern void TF_DeleteStatus(IntPtr s);

/// <summary>
/// Return the code record in *s.
@@ -20,7 +20,7 @@ namespace Tensorflow
/// <param name="s"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern unsafe TF_Code TF_GetCode(IntPtr s);
public static extern TF_Code TF_GetCode(IntPtr s);

/// <summary>
/// Return a pointer to the (null-terminated) error message in *s.
@@ -30,7 +30,7 @@ namespace Tensorflow
/// <param name="s"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern unsafe string TF_Message(IntPtr s);
public static extern IntPtr TF_Message(IntPtr s);

/// <summary>
/// Return a new status object.


+ 8
- 1
src/TensorFlowNET.Core/c_api.cs View File

@@ -12,20 +12,27 @@ namespace Tensorflow
/// The API leans towards simplicity and uniformity instead of convenience
/// since most usage will be by language specific wrappers.
///
/// The params type mapping between .net and c_api
/// The params type mapping between c_api and .NET
/// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op)
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph)
/// struct => struct (TF_Output output) => (TF_Output output)
/// struct* => struct (TF_Output* output) => (TF_Output[] output)
/// const char* => string
/// int32_t => int
/// int64_t* => long[]
/// size_t* => unlong[]
/// void* => IntPtr
/// string => IntPtr c_api.StringPiece(IntPtr)
/// </summary>
public static partial class c_api
{
public const string TensorFlowLibName = "tensorflow";

public static string StringPiece(IntPtr handle)
{
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator);

[DllImport(TensorFlowLibName)]


+ 1
- 1
src/TensorFlowNET.Core/tf.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow
context.default_execution_mode = Context.EAGER_MODE;
}

public static string VERSION => Marshal.PtrToStringAnsi(c_api.TF_Version());
public static string VERSION => c_api.StringPiece(c_api.TF_Version());

public static Graph get_default_graph()
{


+ 2
- 5
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -39,17 +39,14 @@ namespace TensorFlowNET.UnitTest
// Test not found errors in TF_Operation*() query functions.
Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
//Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
//Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message);
Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);

// Make a constant oper with the scalar "3".
var three = c_test_util.ScalarConst(3, graph, s);

// Add oper.
var add = c_test_util.Add(feed, three, graph, s);

NodeDef node_def = null;
c_test_util.GetNodeDef(feed, ref node_def);
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest
public void addInConstant()
{
var a = tf.constant(4.0f);
var b = tf.placeholder(tf.float32);
var b = tf.constant(5.0f);
var c = tf.add(a, b);

using (var sess = tf.Session())


+ 4
- 2
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -23,11 +23,13 @@ namespace TensorFlowNET.UnitTest
{
var desc = c_api.TF_NewOperation(graph, "AddN", name);

c_api.TF_AddInputList(desc, new TF_Output[]
var inputs = new TF_Output[]
{
new TF_Output(l, 0),
new TF_Output(r, 0),
}, 2);
};

c_api.TF_AddInputList(desc, inputs, inputs.Length);

op = c_api.TF_FinishOperation(desc, s);
s.Check();


Loading…
Cancel
Save