Browse Source

test

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
3536f6a2bd
11 changed files with 63 additions and 66 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  3. +14
    -14
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  4. +4
    -0
      src/TensorFlowNET.Core/Sessions/Session.cs
  5. +9
    -1
      src/TensorFlowNET.Core/Sessions/c_api.session.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +1
    -1
      src/TensorFlowNET.Core/tf.cs
  8. +4
    -4
      test/TensorFlowNET.UnitTest/CSession.cs
  9. +11
    -11
      test/TensorFlowNET.UnitTest/GraphTest.cs
  10. +7
    -7
      test/TensorFlowNET.UnitTest/TensorTest.cs
  11. +9
    -24
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 2
- 2
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -13,13 +13,13 @@ namespace Tensorflow
var num_return_outputs = opts.NumReturnOutputs;
var return_outputs = new TF_Output[num_return_outputs];
int size = Marshal.SizeOf<TF_Output>();
TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(size * num_return_outputs);
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);

c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);
for (int i = 0; i < num_return_outputs; i++)
{
var handle = return_output_handle + i * size;
return_outputs[i] = new TF_Output((*handle).oper, (*handle).index);
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
}

return return_outputs;


+ 1
- 1
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
/// <param name="num_return_outputs">int</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status);
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, IntPtr status);

/// <summary>
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and


+ 14
- 14
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -9,12 +9,12 @@ namespace Tensorflow
{
public class BaseSession : IDisposable
{
private Graph _graph;
private bool _opened;
private bool _closed;
private int _current_version;
private byte[] _target;
private IntPtr _session;
protected Graph _graph;
protected bool _opened;
protected bool _closed;
protected int _current_version;
protected byte[] _target;
protected IntPtr _session;

public BaseSession(string target = "", Graph graph = null)
{
@@ -28,11 +28,11 @@ namespace Tensorflow
}

_target = UTF8Encoding.UTF8.GetBytes(target);
//var opts = c_api.TF_NewSessionOptions();
//var status = new Status();
//_session = c_api.TF_NewSession(_graph, opts, status);
var opts = c_api.TF_NewSessionOptions();
var status = new Status();
_session = c_api.TF_NewSession(_graph, opts, status);

//c_api.TF_DeleteSessionOptions(opts);
c_api.TF_DeleteSessionOptions(opts);
}

public void Dispose()
@@ -102,18 +102,18 @@ namespace Tensorflow

var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

/*c_api.TF_SessionRun(_session,
run_options: IntPtr.Zero,
c_api.TF_SessionRun(_session,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: new IntPtr[] { },
ninputs: 0,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: new IntPtr[] { },
target_opers: IntPtr.Zero,
ntargets: 0,
run_metadata: IntPtr.Zero,
status: status);*/
status: status);

var result = output_values.Select(x => c_api.TF_TensorData(x))
.Select(x => (object)*(float*)x)


+ 4
- 0
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -8,6 +8,10 @@ namespace Tensorflow
{
private IntPtr _handle;

public Session(string target = "", Graph graph = null)
{
}

public Session(IntPtr handle)
{
_handle = handle;


+ 9
- 1
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -70,10 +70,18 @@ namespace Tensorflow
/// <param name="ntargets">int</param>
/// <param name="run_metadata">TF_Buffer*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
TF_Output[] inputs, IntPtr[] input_values, int ninputs,
TF_Output[] outputs, IntPtr[] output_values, int noutputs,
IntPtr target_opers, int ntargets,
IntPtr run_metadata,
IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
IntPtr inputs, IntPtr input_values, int ninputs,
IntPtr outputs, ref IntPtr output_values, int noutputs,
IntPtr outputs, IntPtr[] output_values, int noutputs,
IntPtr target_opers, int ntargets,
IntPtr run_metadata,
IntPtr status);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow
{
var dims = new long[rank];
for (int i = 0; i < rank; i++)
shape[i] = c_api.TF_Dim(_handle, i);
dims[i] = c_api.TF_Dim(_handle, i);

return dims;
}


+ 1
- 1
src/TensorFlowNET.Core/tf.cs View File

@@ -51,7 +51,7 @@ namespace Tensorflow

public static Session Session()
{
return (Session)new BaseSession();
return new Session();
}
}
}

+ 4
- 4
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -79,17 +79,17 @@ namespace TensorFlowNET.UnitTest
IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0];
IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0];
IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0];
IntPtr output_values_ptr = output_values_.Count == 0 ? IntPtr.Zero : output_values_[0];
IntPtr[] output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0];
IntPtr targets_ptr = IntPtr.Zero;

c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_.Count,
outputs_ptr, ref output_values_ptr, outputs_.Count,
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 0,
outputs_ptr, output_values_ptr, outputs_.Count,
targets_ptr, targets_.Count,
IntPtr.Zero, s);

s.Check();

output_values_[0] = output_values_ptr;
output_values_[0] = output_values_ptr[0];
}

public IntPtr output_tensor(int i)


+ 11
- 11
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -77,7 +77,7 @@ namespace TensorFlowNET.UnitTest
ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
EXPECT_EQ(DataType.DtInt32, attr_value.Type);
ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
EXPECT_EQ(2, attr_value.I);
EXPECT_EQ(2, (int)attr_value.I);

// Placeholder oper now has a consumer.
EXPECT_EQ(1, feed.OutputNumConsumers(0));
@@ -353,7 +353,7 @@ namespace TensorFlowNET.UnitTest
c_test_util.Add(feed, scalar, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

//graph.Dispose();
graph.Dispose();
s.Dispose();
}

@@ -368,12 +368,12 @@ namespace TensorFlowNET.UnitTest
var graph = new Graph();

// Create a graph with two nodes: x and 3
var feed = c_test_util.Placeholder(graph, s);
EXPECT_EQ(feed, graph.OperationByName("feed"));
var scalar = c_test_util.ScalarConst(3, graph, s);
EXPECT_EQ(scalar, graph.OperationByName("scalar"));
var neg = c_test_util.Neg(scalar, graph, s);
EXPECT_EQ(neg, graph.OperationByName("neg"));
c_test_util.Placeholder(graph, s);
ASSERT_TRUE(graph.OperationByName("feed") != null);
var oper = c_test_util.ScalarConst(3, graph, s);
ASSERT_TRUE(graph.OperationByName("scalar") != null);
c_test_util.Neg(oper, graph, s);
ASSERT_TRUE(graph.OperationByName("neg") != null);

// Export to a GraphDef.
var graph_def = graph.ToGraphDef(s);
@@ -389,9 +389,9 @@ namespace TensorFlowNET.UnitTest
var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

scalar = graph.OperationByName("scalar");
feed = graph.OperationByName("feed");
neg = graph.OperationByName("neg");
var scalar = graph.OperationByName("scalar");
var feed = graph.OperationByName("feed");
var neg = graph.OperationByName("neg");
ASSERT_TRUE(scalar != IntPtr.Zero);
ASSERT_TRUE(feed != IntPtr.Zero);
ASSERT_TRUE(neg != IntPtr.Zero);


+ 7
- 7
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -57,9 +57,9 @@ namespace TensorFlowNET.UnitTest

EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim);
EXPECT_EQ(tensor.shape[0], nd.shape[0]);
EXPECT_EQ(tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float));
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));
}

@@ -118,8 +118,8 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);

// Try to set 'unknown' with same rank on the shape and see that
// it doesn't change.
@@ -130,8 +130,8 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);

// Try to fetch a shape with the wrong num_dims
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);


+ 9
- 24
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -13,13 +13,6 @@ namespace TensorFlowNET.UnitTest
public static class c_test_util
{
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add")
{
Operation op = null;
AddOpHelper(l, r, graph, s, name, ref op, true);
return op;
}

public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check)
{
var desc = c_api.TF_NewOperation(graph, "AddN", name);

@@ -31,8 +24,10 @@ namespace TensorFlowNET.UnitTest

c_api.TF_AddInputList(desc, inputs, inputs.Length);

op = c_api.TF_FinishOperation(desc, s);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();

return op;
}

public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s)
@@ -196,39 +191,29 @@ namespace TensorFlowNET.UnitTest
return op;
}

public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype);
if(dims != null)
if (dims != null)
{
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length);
}
op = c_api.TF_FinishOperation(desc, s);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();
}

public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null)
{
Operation op = null;
PlaceholderHelper(graph, s, name, dtype, dims, ref op);
return op;
}

public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op)
public static Operation Const(Tensor t, Graph graph, Status s, string name)
{
var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t, s);
s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype);
op = c_api.TF_FinishOperation(desc, s);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();
}

public static Operation Const(Tensor t, Graph graph, Status s, string name)
{
Operation op = null;
ConstHelper(t, graph, s, name, ref op);
return op;
}

@@ -247,7 +232,7 @@ namespace TensorFlowNET.UnitTest
(IntPtr values, IntPtr len, ref bool closure) =>
{
// Free the original buffer and set flag
Marshal.FreeHGlobal(dotHandle);
// Marshal.FreeHGlobal(dotHandle);
}, ref deallocator_called);
}
}


Loading…
Cancel
Save