@@ -13,13 +13,13 @@ namespace Tensorflow | |||
var num_return_outputs = opts.NumReturnOutputs; | |||
var return_outputs = new TF_Output[num_return_outputs]; | |||
int size = Marshal.SizeOf<TF_Output>(); | |||
TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(size * num_return_outputs); | |||
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | |||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||
for (int i = 0; i < num_return_outputs; i++) | |||
{ | |||
var handle = return_output_handle + i * size; | |||
return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); | |||
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||
} | |||
return return_outputs; | |||
@@ -61,7 +61,7 @@ namespace Tensorflow | |||
/// <param name="num_return_outputs">int</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); | |||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, IntPtr status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
@@ -9,12 +9,12 @@ namespace Tensorflow | |||
{ | |||
public class BaseSession : IDisposable | |||
{ | |||
private Graph _graph; | |||
private bool _opened; | |||
private bool _closed; | |||
private int _current_version; | |||
private byte[] _target; | |||
private IntPtr _session; | |||
protected Graph _graph; | |||
protected bool _opened; | |||
protected bool _closed; | |||
protected int _current_version; | |||
protected byte[] _target; | |||
protected IntPtr _session; | |||
public BaseSession(string target = "", Graph graph = null) | |||
{ | |||
@@ -28,11 +28,11 @@ namespace Tensorflow | |||
} | |||
_target = UTF8Encoding.UTF8.GetBytes(target); | |||
//var opts = c_api.TF_NewSessionOptions(); | |||
//var status = new Status(); | |||
//_session = c_api.TF_NewSession(_graph, opts, status); | |||
var opts = c_api.TF_NewSessionOptions(); | |||
var status = new Status(); | |||
_session = c_api.TF_NewSession(_graph, opts, status); | |||
//c_api.TF_DeleteSessionOptions(opts); | |||
c_api.TF_DeleteSessionOptions(opts); | |||
} | |||
public void Dispose() | |||
@@ -102,18 +102,18 @@ namespace Tensorflow | |||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
/*c_api.TF_SessionRun(_session, | |||
run_options: IntPtr.Zero, | |||
c_api.TF_SessionRun(_session, | |||
run_options: null, | |||
inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
input_values: new IntPtr[] { }, | |||
ninputs: 0, | |||
outputs: fetch_list, | |||
output_values: output_values, | |||
noutputs: fetch_list.Length, | |||
target_opers: new IntPtr[] { }, | |||
target_opers: IntPtr.Zero, | |||
ntargets: 0, | |||
run_metadata: IntPtr.Zero, | |||
status: status);*/ | |||
status: status); | |||
var result = output_values.Select(x => c_api.TF_TensorData(x)) | |||
.Select(x => (object)*(float*)x) | |||
@@ -8,6 +8,10 @@ namespace Tensorflow | |||
{ | |||
private IntPtr _handle; | |||
public Session(string target = "", Graph graph = null) | |||
{ | |||
} | |||
public Session(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
@@ -70,10 +70,18 @@ namespace Tensorflow | |||
/// <param name="ntargets">int</param> | |||
/// <param name="run_metadata">TF_Buffer*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||
TF_Output[] inputs, IntPtr[] input_values, int ninputs, | |||
TF_Output[] outputs, IntPtr[] output_values, int noutputs, | |||
IntPtr target_opers, int ntargets, | |||
IntPtr run_metadata, | |||
IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||
IntPtr inputs, IntPtr input_values, int ninputs, | |||
IntPtr outputs, ref IntPtr output_values, int noutputs, | |||
IntPtr outputs, IntPtr[] output_values, int noutputs, | |||
IntPtr target_opers, int ntargets, | |||
IntPtr run_metadata, | |||
IntPtr status); | |||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||
{ | |||
var dims = new long[rank]; | |||
for (int i = 0; i < rank; i++) | |||
shape[i] = c_api.TF_Dim(_handle, i); | |||
dims[i] = c_api.TF_Dim(_handle, i); | |||
return dims; | |||
} | |||
@@ -51,7 +51,7 @@ namespace Tensorflow | |||
public static Session Session() | |||
{ | |||
return (Session)new BaseSession(); | |||
return new Session(); | |||
} | |||
} | |||
} |
@@ -79,17 +79,17 @@ namespace TensorFlowNET.UnitTest | |||
IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0]; | |||
IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0]; | |||
IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0]; | |||
IntPtr output_values_ptr = output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; | |||
IntPtr[] output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; | |||
IntPtr targets_ptr = IntPtr.Zero; | |||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_.Count, | |||
outputs_ptr, ref output_values_ptr, outputs_.Count, | |||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 0, | |||
outputs_ptr, output_values_ptr, outputs_.Count, | |||
targets_ptr, targets_.Count, | |||
IntPtr.Zero, s); | |||
s.Check(); | |||
output_values_[0] = output_values_ptr; | |||
output_values_[0] = output_values_ptr[0]; | |||
} | |||
public IntPtr output_tensor(int i) | |||
@@ -77,7 +77,7 @@ namespace TensorFlowNET.UnitTest | |||
ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||
EXPECT_EQ(DataType.DtInt32, attr_value.Type); | |||
ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||
EXPECT_EQ(2, attr_value.I); | |||
EXPECT_EQ(2, (int)attr_value.I); | |||
// Placeholder oper now has a consumer. | |||
EXPECT_EQ(1, feed.OutputNumConsumers(0)); | |||
@@ -353,7 +353,7 @@ namespace TensorFlowNET.UnitTest | |||
c_test_util.Add(feed, scalar, graph, s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
//graph.Dispose(); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
@@ -368,12 +368,12 @@ namespace TensorFlowNET.UnitTest | |||
var graph = new Graph(); | |||
// Create a graph with two nodes: x and 3 | |||
var feed = c_test_util.Placeholder(graph, s); | |||
EXPECT_EQ(feed, graph.OperationByName("feed")); | |||
var scalar = c_test_util.ScalarConst(3, graph, s); | |||
EXPECT_EQ(scalar, graph.OperationByName("scalar")); | |||
var neg = c_test_util.Neg(scalar, graph, s); | |||
EXPECT_EQ(neg, graph.OperationByName("neg")); | |||
c_test_util.Placeholder(graph, s); | |||
ASSERT_TRUE(graph.OperationByName("feed") != null); | |||
var oper = c_test_util.ScalarConst(3, graph, s); | |||
ASSERT_TRUE(graph.OperationByName("scalar") != null); | |||
c_test_util.Neg(oper, graph, s); | |||
ASSERT_TRUE(graph.OperationByName("neg") != null); | |||
// Export to a GraphDef. | |||
var graph_def = graph.ToGraphDef(s); | |||
@@ -389,9 +389,9 @@ namespace TensorFlowNET.UnitTest | |||
var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
scalar = graph.OperationByName("scalar"); | |||
feed = graph.OperationByName("feed"); | |||
neg = graph.OperationByName("neg"); | |||
var scalar = graph.OperationByName("scalar"); | |||
var feed = graph.OperationByName("feed"); | |||
var neg = graph.OperationByName("neg"); | |||
ASSERT_TRUE(scalar != IntPtr.Zero); | |||
ASSERT_TRUE(feed != IntPtr.Zero); | |||
ASSERT_TRUE(neg != IntPtr.Zero); | |||
@@ -57,9 +57,9 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||
EXPECT_EQ(tensor.rank, nd.ndim); | |||
EXPECT_EQ(tensor.shape[0], nd.shape[0]); | |||
EXPECT_EQ(tensor.shape[1], nd.shape[1]); | |||
EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | |||
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | |||
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | |||
} | |||
@@ -118,8 +118,8 @@ namespace TensorFlowNET.UnitTest | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
EXPECT_EQ(2, num_dims); | |||
EXPECT_EQ(2, returned_dims[0]); | |||
EXPECT_EQ(3, returned_dims[1]); | |||
EXPECT_EQ(2, (int)returned_dims[0]); | |||
EXPECT_EQ(3, (int)returned_dims[1]); | |||
// Try to set 'unknown' with same rank on the shape and see that | |||
// it doesn't change. | |||
@@ -130,8 +130,8 @@ namespace TensorFlowNET.UnitTest | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
EXPECT_EQ(2, num_dims); | |||
EXPECT_EQ(2, returned_dims[0]); | |||
EXPECT_EQ(3, returned_dims[1]); | |||
EXPECT_EQ(2, (int)returned_dims[0]); | |||
EXPECT_EQ(3, (int)returned_dims[1]); | |||
// Try to fetch a shape with the wrong num_dims | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||
@@ -13,13 +13,6 @@ namespace TensorFlowNET.UnitTest | |||
public static class c_test_util | |||
{ | |||
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | |||
{ | |||
Operation op = null; | |||
AddOpHelper(l, r, graph, s, name, ref op, true); | |||
return op; | |||
} | |||
public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||
@@ -31,8 +24,10 @@ namespace TensorFlowNET.UnitTest | |||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||
op = c_api.TF_FinishOperation(desc, s); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
} | |||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | |||
@@ -196,39 +191,29 @@ namespace TensorFlowNET.UnitTest | |||
return op; | |||
} | |||
public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | |||
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||
c_api.TF_SetAttrType(desc, "dtype", dtype); | |||
if(dims != null) | |||
if (dims != null) | |||
{ | |||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||
} | |||
op = c_api.TF_FinishOperation(desc, s); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
} | |||
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | |||
{ | |||
Operation op = null; | |||
PlaceholderHelper(graph, s, name, dtype, dims, ref op); | |||
return op; | |||
} | |||
public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op) | |||
public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Const", name); | |||
c_api.TF_SetAttrTensor(desc, "value", t, s); | |||
s.Check(); | |||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||
op = c_api.TF_FinishOperation(desc, s); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
} | |||
public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||
{ | |||
Operation op = null; | |||
ConstHelper(t, graph, s, name, ref op); | |||
return op; | |||
} | |||
@@ -247,7 +232,7 @@ namespace TensorFlowNET.UnitTest | |||
(IntPtr values, IntPtr len, ref bool closure) => | |||
{ | |||
// Free the original buffer and set flag | |||
Marshal.FreeHGlobal(dotHandle); | |||
// Marshal.FreeHGlobal(dotHandle); | |||
}, ref deallocator_called); | |||
} | |||
} | |||