Browse Source

test

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
3536f6a2bd
11 changed files with 63 additions and 66 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  3. +14
    -14
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  4. +4
    -0
      src/TensorFlowNET.Core/Sessions/Session.cs
  5. +9
    -1
      src/TensorFlowNET.Core/Sessions/c_api.session.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +1
    -1
      src/TensorFlowNET.Core/tf.cs
  8. +4
    -4
      test/TensorFlowNET.UnitTest/CSession.cs
  9. +11
    -11
      test/TensorFlowNET.UnitTest/GraphTest.cs
  10. +7
    -7
      test/TensorFlowNET.UnitTest/TensorTest.cs
  11. +9
    -24
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 2
- 2
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -13,13 +13,13 @@ namespace Tensorflow
var num_return_outputs = opts.NumReturnOutputs; var num_return_outputs = opts.NumReturnOutputs;
var return_outputs = new TF_Output[num_return_outputs]; var return_outputs = new TF_Output[num_return_outputs];
int size = Marshal.SizeOf<TF_Output>(); int size = Marshal.SizeOf<TF_Output>();
TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(size * num_return_outputs);
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);


c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);
for (int i = 0; i < num_return_outputs; i++) for (int i = 0; i < num_return_outputs; i++)
{ {
var handle = return_output_handle + i * size; var handle = return_output_handle + i * size;
return_outputs[i] = new TF_Output((*handle).oper, (*handle).index);
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
} }


return return_outputs; return return_outputs;


+ 1
- 1
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
/// <param name="num_return_outputs">int</param> /// <param name="num_return_outputs">int</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status);
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, IntPtr status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and


+ 14
- 14
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -9,12 +9,12 @@ namespace Tensorflow
{ {
public class BaseSession : IDisposable public class BaseSession : IDisposable
{ {
private Graph _graph;
private bool _opened;
private bool _closed;
private int _current_version;
private byte[] _target;
private IntPtr _session;
protected Graph _graph;
protected bool _opened;
protected bool _closed;
protected int _current_version;
protected byte[] _target;
protected IntPtr _session;


public BaseSession(string target = "", Graph graph = null) public BaseSession(string target = "", Graph graph = null)
{ {
@@ -28,11 +28,11 @@ namespace Tensorflow
} }


_target = UTF8Encoding.UTF8.GetBytes(target); _target = UTF8Encoding.UTF8.GetBytes(target);
//var opts = c_api.TF_NewSessionOptions();
//var status = new Status();
//_session = c_api.TF_NewSession(_graph, opts, status);
var opts = c_api.TF_NewSessionOptions();
var status = new Status();
_session = c_api.TF_NewSession(_graph, opts, status);


//c_api.TF_DeleteSessionOptions(opts);
c_api.TF_DeleteSessionOptions(opts);
} }


public void Dispose() public void Dispose()
@@ -102,18 +102,18 @@ namespace Tensorflow


var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();


/*c_api.TF_SessionRun(_session,
run_options: IntPtr.Zero,
c_api.TF_SessionRun(_session,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(), inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: new IntPtr[] { }, input_values: new IntPtr[] { },
ninputs: 0, ninputs: 0,
outputs: fetch_list, outputs: fetch_list,
output_values: output_values, output_values: output_values,
noutputs: fetch_list.Length, noutputs: fetch_list.Length,
target_opers: new IntPtr[] { },
target_opers: IntPtr.Zero,
ntargets: 0, ntargets: 0,
run_metadata: IntPtr.Zero, run_metadata: IntPtr.Zero,
status: status);*/
status: status);


var result = output_values.Select(x => c_api.TF_TensorData(x)) var result = output_values.Select(x => c_api.TF_TensorData(x))
.Select(x => (object)*(float*)x) .Select(x => (object)*(float*)x)


+ 4
- 0
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -8,6 +8,10 @@ namespace Tensorflow
{ {
private IntPtr _handle; private IntPtr _handle;


public Session(string target = "", Graph graph = null)
{
}

public Session(IntPtr handle) public Session(IntPtr handle)
{ {
_handle = handle; _handle = handle;


+ 9
- 1
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -70,10 +70,18 @@ namespace Tensorflow
/// <param name="ntargets">int</param> /// <param name="ntargets">int</param>
/// <param name="run_metadata">TF_Buffer*</param> /// <param name="run_metadata">TF_Buffer*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
TF_Output[] inputs, IntPtr[] input_values, int ninputs,
TF_Output[] outputs, IntPtr[] output_values, int noutputs,
IntPtr target_opers, int ntargets,
IntPtr run_metadata,
IntPtr status);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
IntPtr inputs, IntPtr input_values, int ninputs, IntPtr inputs, IntPtr input_values, int ninputs,
IntPtr outputs, ref IntPtr output_values, int noutputs,
IntPtr outputs, IntPtr[] output_values, int noutputs,
IntPtr target_opers, int ntargets, IntPtr target_opers, int ntargets,
IntPtr run_metadata, IntPtr run_metadata,
IntPtr status); IntPtr status);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow
{ {
var dims = new long[rank]; var dims = new long[rank];
for (int i = 0; i < rank; i++) for (int i = 0; i < rank; i++)
shape[i] = c_api.TF_Dim(_handle, i);
dims[i] = c_api.TF_Dim(_handle, i);


return dims; return dims;
} }


+ 1
- 1
src/TensorFlowNET.Core/tf.cs View File

@@ -51,7 +51,7 @@ namespace Tensorflow


public static Session Session() public static Session Session()
{ {
return (Session)new BaseSession();
return new Session();
} }
} }
} }

+ 4
- 4
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -79,17 +79,17 @@ namespace TensorFlowNET.UnitTest
IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0]; IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0];
IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0]; IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0];
IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0]; IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0];
IntPtr output_values_ptr = output_values_.Count == 0 ? IntPtr.Zero : output_values_[0];
IntPtr[] output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0];
IntPtr targets_ptr = IntPtr.Zero; IntPtr targets_ptr = IntPtr.Zero;


c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_.Count,
outputs_ptr, ref output_values_ptr, outputs_.Count,
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 0,
outputs_ptr, output_values_ptr, outputs_.Count,
targets_ptr, targets_.Count, targets_ptr, targets_.Count,
IntPtr.Zero, s); IntPtr.Zero, s);


s.Check(); s.Check();


output_values_[0] = output_values_ptr;
output_values_[0] = output_values_ptr[0];
} }


public IntPtr output_tensor(int i) public IntPtr output_tensor(int i)


+ 11
- 11
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -77,7 +77,7 @@ namespace TensorFlowNET.UnitTest
ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
EXPECT_EQ(DataType.DtInt32, attr_value.Type); EXPECT_EQ(DataType.DtInt32, attr_value.Type);
ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
EXPECT_EQ(2, attr_value.I);
EXPECT_EQ(2, (int)attr_value.I);


// Placeholder oper now has a consumer. // Placeholder oper now has a consumer.
EXPECT_EQ(1, feed.OutputNumConsumers(0)); EXPECT_EQ(1, feed.OutputNumConsumers(0));
@@ -353,7 +353,7 @@ namespace TensorFlowNET.UnitTest
c_test_util.Add(feed, scalar, graph, s); c_test_util.Add(feed, scalar, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);


//graph.Dispose();
graph.Dispose();
s.Dispose(); s.Dispose();
} }


@@ -368,12 +368,12 @@ namespace TensorFlowNET.UnitTest
var graph = new Graph(); var graph = new Graph();


// Create a graph with two nodes: x and 3 // Create a graph with two nodes: x and 3
var feed = c_test_util.Placeholder(graph, s);
EXPECT_EQ(feed, graph.OperationByName("feed"));
var scalar = c_test_util.ScalarConst(3, graph, s);
EXPECT_EQ(scalar, graph.OperationByName("scalar"));
var neg = c_test_util.Neg(scalar, graph, s);
EXPECT_EQ(neg, graph.OperationByName("neg"));
c_test_util.Placeholder(graph, s);
ASSERT_TRUE(graph.OperationByName("feed") != null);
var oper = c_test_util.ScalarConst(3, graph, s);
ASSERT_TRUE(graph.OperationByName("scalar") != null);
c_test_util.Neg(oper, graph, s);
ASSERT_TRUE(graph.OperationByName("neg") != null);


// Export to a GraphDef. // Export to a GraphDef.
var graph_def = graph.ToGraphDef(s); var graph_def = graph.ToGraphDef(s);
@@ -389,9 +389,9 @@ namespace TensorFlowNET.UnitTest
var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);


scalar = graph.OperationByName("scalar");
feed = graph.OperationByName("feed");
neg = graph.OperationByName("neg");
var scalar = graph.OperationByName("scalar");
var feed = graph.OperationByName("feed");
var neg = graph.OperationByName("neg");
ASSERT_TRUE(scalar != IntPtr.Zero); ASSERT_TRUE(scalar != IntPtr.Zero);
ASSERT_TRUE(feed != IntPtr.Zero); ASSERT_TRUE(feed != IntPtr.Zero);
ASSERT_TRUE(neg != IntPtr.Zero); ASSERT_TRUE(neg != IntPtr.Zero);


+ 7
- 7
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -57,9 +57,9 @@ namespace TensorFlowNET.UnitTest


EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim); EXPECT_EQ(tensor.rank, nd.ndim);
EXPECT_EQ(tensor.shape[0], nd.shape[0]);
EXPECT_EQ(tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float));
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));
} }


@@ -118,8 +118,8 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims); EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);


// Try to set 'unknown' with same rank on the shape and see that // Try to set 'unknown' with same rank on the shape and see that
// it doesn't change. // it doesn't change.
@@ -130,8 +130,8 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims); EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);


// Try to fetch a shape with the wrong num_dims // Try to fetch a shape with the wrong num_dims
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);


+ 9
- 24
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -13,13 +13,6 @@ namespace TensorFlowNET.UnitTest
public static class c_test_util public static class c_test_util
{ {
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add")
{
Operation op = null;
AddOpHelper(l, r, graph, s, name, ref op, true);
return op;
}

public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check)
{ {
var desc = c_api.TF_NewOperation(graph, "AddN", name); var desc = c_api.TF_NewOperation(graph, "AddN", name);


@@ -31,8 +24,10 @@ namespace TensorFlowNET.UnitTest


c_api.TF_AddInputList(desc, inputs, inputs.Length); c_api.TF_AddInputList(desc, inputs, inputs.Length);


op = c_api.TF_FinishOperation(desc, s);
var op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();

return op;
} }


public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s)
@@ -196,39 +191,29 @@ namespace TensorFlowNET.UnitTest
return op; return op;
} }


public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null)
{ {
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype); c_api.TF_SetAttrType(desc, "dtype", dtype);
if(dims != null)
if (dims != null)
{ {
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length);
} }
op = c_api.TF_FinishOperation(desc, s);
var op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();
}


public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null)
{
Operation op = null;
PlaceholderHelper(graph, s, name, dtype, dims, ref op);
return op; return op;
} }


public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op)
public static Operation Const(Tensor t, Graph graph, Status s, string name)
{ {
var desc = c_api.TF_NewOperation(graph, "Const", name); var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t, s); c_api.TF_SetAttrTensor(desc, "value", t, s);
s.Check(); s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype); c_api.TF_SetAttrType(desc, "dtype", t.dtype);
op = c_api.TF_FinishOperation(desc, s);
var op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();
}


public static Operation Const(Tensor t, Graph graph, Status s, string name)
{
Operation op = null;
ConstHelper(t, graph, s, name, ref op);
return op; return op;
} }


@@ -247,7 +232,7 @@ namespace TensorFlowNET.UnitTest
(IntPtr values, IntPtr len, ref bool closure) => (IntPtr values, IntPtr len, ref bool closure) =>
{ {
// Free the original buffer and set flag // Free the original buffer and set flag
Marshal.FreeHGlobal(dotHandle);
// Marshal.FreeHGlobal(dotHandle);
}, ref deallocator_called); }, ref deallocator_called);
} }
} }


Loading…
Cancel
Save