@@ -13,13 +13,13 @@ namespace Tensorflow | |||||
var num_return_outputs = opts.NumReturnOutputs; | var num_return_outputs = opts.NumReturnOutputs; | ||||
var return_outputs = new TF_Output[num_return_outputs]; | var return_outputs = new TF_Output[num_return_outputs]; | ||||
int size = Marshal.SizeOf<TF_Output>(); | int size = Marshal.SizeOf<TF_Output>(); | ||||
TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(size * num_return_outputs); | |||||
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | |||||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | ||||
for (int i = 0; i < num_return_outputs; i++) | for (int i = 0; i < num_return_outputs; i++) | ||||
{ | { | ||||
var handle = return_output_handle + i * size; | var handle = return_output_handle + i * size; | ||||
return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); | |||||
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||||
} | } | ||||
return return_outputs; | return return_outputs; | ||||
@@ -61,7 +61,7 @@ namespace Tensorflow | |||||
/// <param name="num_return_outputs">int</param> | /// <param name="num_return_outputs">int</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); | |||||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, IntPtr status); | |||||
/// <summary> | /// <summary> | ||||
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | ||||
@@ -9,12 +9,12 @@ namespace Tensorflow | |||||
{ | { | ||||
public class BaseSession : IDisposable | public class BaseSession : IDisposable | ||||
{ | { | ||||
private Graph _graph; | |||||
private bool _opened; | |||||
private bool _closed; | |||||
private int _current_version; | |||||
private byte[] _target; | |||||
private IntPtr _session; | |||||
protected Graph _graph; | |||||
protected bool _opened; | |||||
protected bool _closed; | |||||
protected int _current_version; | |||||
protected byte[] _target; | |||||
protected IntPtr _session; | |||||
public BaseSession(string target = "", Graph graph = null) | public BaseSession(string target = "", Graph graph = null) | ||||
{ | { | ||||
@@ -28,11 +28,11 @@ namespace Tensorflow | |||||
} | } | ||||
_target = UTF8Encoding.UTF8.GetBytes(target); | _target = UTF8Encoding.UTF8.GetBytes(target); | ||||
//var opts = c_api.TF_NewSessionOptions(); | |||||
//var status = new Status(); | |||||
//_session = c_api.TF_NewSession(_graph, opts, status); | |||||
var opts = c_api.TF_NewSessionOptions(); | |||||
var status = new Status(); | |||||
_session = c_api.TF_NewSession(_graph, opts, status); | |||||
//c_api.TF_DeleteSessionOptions(opts); | |||||
c_api.TF_DeleteSessionOptions(opts); | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -102,18 +102,18 @@ namespace Tensorflow | |||||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | ||||
/*c_api.TF_SessionRun(_session, | |||||
run_options: IntPtr.Zero, | |||||
c_api.TF_SessionRun(_session, | |||||
run_options: null, | |||||
inputs: feed_dict.Select(f => f.Key).ToArray(), | inputs: feed_dict.Select(f => f.Key).ToArray(), | ||||
input_values: new IntPtr[] { }, | input_values: new IntPtr[] { }, | ||||
ninputs: 0, | ninputs: 0, | ||||
outputs: fetch_list, | outputs: fetch_list, | ||||
output_values: output_values, | output_values: output_values, | ||||
noutputs: fetch_list.Length, | noutputs: fetch_list.Length, | ||||
target_opers: new IntPtr[] { }, | |||||
target_opers: IntPtr.Zero, | |||||
ntargets: 0, | ntargets: 0, | ||||
run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
status: status);*/ | |||||
status: status); | |||||
var result = output_values.Select(x => c_api.TF_TensorData(x)) | var result = output_values.Select(x => c_api.TF_TensorData(x)) | ||||
.Select(x => (object)*(float*)x) | .Select(x => (object)*(float*)x) | ||||
@@ -8,6 +8,10 @@ namespace Tensorflow | |||||
{ | { | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
public Session(string target = "", Graph graph = null) | |||||
{ | |||||
} | |||||
public Session(IntPtr handle) | public Session(IntPtr handle) | ||||
{ | { | ||||
_handle = handle; | _handle = handle; | ||||
@@ -70,10 +70,18 @@ namespace Tensorflow | |||||
/// <param name="ntargets">int</param> | /// <param name="ntargets">int</param> | ||||
/// <param name="run_metadata">TF_Buffer*</param> | /// <param name="run_metadata">TF_Buffer*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||||
TF_Output[] inputs, IntPtr[] input_values, int ninputs, | |||||
TF_Output[] outputs, IntPtr[] output_values, int noutputs, | |||||
IntPtr target_opers, int ntargets, | |||||
IntPtr run_metadata, | |||||
IntPtr status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | ||||
IntPtr inputs, IntPtr input_values, int ninputs, | IntPtr inputs, IntPtr input_values, int ninputs, | ||||
IntPtr outputs, ref IntPtr output_values, int noutputs, | |||||
IntPtr outputs, IntPtr[] output_values, int noutputs, | |||||
IntPtr target_opers, int ntargets, | IntPtr target_opers, int ntargets, | ||||
IntPtr run_metadata, | IntPtr run_metadata, | ||||
IntPtr status); | IntPtr status); | ||||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||||
{ | { | ||||
var dims = new long[rank]; | var dims = new long[rank]; | ||||
for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
shape[i] = c_api.TF_Dim(_handle, i); | |||||
dims[i] = c_api.TF_Dim(_handle, i); | |||||
return dims; | return dims; | ||||
} | } | ||||
@@ -51,7 +51,7 @@ namespace Tensorflow | |||||
public static Session Session() | public static Session Session() | ||||
{ | { | ||||
return (Session)new BaseSession(); | |||||
return new Session(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -79,17 +79,17 @@ namespace TensorFlowNET.UnitTest | |||||
IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0]; | IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0]; | ||||
IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0]; | IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0]; | ||||
IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0]; | IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0]; | ||||
IntPtr output_values_ptr = output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; | |||||
IntPtr[] output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; | |||||
IntPtr targets_ptr = IntPtr.Zero; | IntPtr targets_ptr = IntPtr.Zero; | ||||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_.Count, | |||||
outputs_ptr, ref output_values_ptr, outputs_.Count, | |||||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 0, | |||||
outputs_ptr, output_values_ptr, outputs_.Count, | |||||
targets_ptr, targets_.Count, | targets_ptr, targets_.Count, | ||||
IntPtr.Zero, s); | IntPtr.Zero, s); | ||||
s.Check(); | s.Check(); | ||||
output_values_[0] = output_values_ptr; | |||||
output_values_[0] = output_values_ptr[0]; | |||||
} | } | ||||
public IntPtr output_tensor(int i) | public IntPtr output_tensor(int i) | ||||
@@ -77,7 +77,7 @@ namespace TensorFlowNET.UnitTest | |||||
ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | ||||
EXPECT_EQ(DataType.DtInt32, attr_value.Type); | EXPECT_EQ(DataType.DtInt32, attr_value.Type); | ||||
ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | ||||
EXPECT_EQ(2, attr_value.I); | |||||
EXPECT_EQ(2, (int)attr_value.I); | |||||
// Placeholder oper now has a consumer. | // Placeholder oper now has a consumer. | ||||
EXPECT_EQ(1, feed.OutputNumConsumers(0)); | EXPECT_EQ(1, feed.OutputNumConsumers(0)); | ||||
@@ -353,7 +353,7 @@ namespace TensorFlowNET.UnitTest | |||||
c_test_util.Add(feed, scalar, graph, s); | c_test_util.Add(feed, scalar, graph, s); | ||||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
//graph.Dispose(); | |||||
graph.Dispose(); | |||||
s.Dispose(); | s.Dispose(); | ||||
} | } | ||||
@@ -368,12 +368,12 @@ namespace TensorFlowNET.UnitTest | |||||
var graph = new Graph(); | var graph = new Graph(); | ||||
// Create a graph with two nodes: x and 3 | // Create a graph with two nodes: x and 3 | ||||
var feed = c_test_util.Placeholder(graph, s); | |||||
EXPECT_EQ(feed, graph.OperationByName("feed")); | |||||
var scalar = c_test_util.ScalarConst(3, graph, s); | |||||
EXPECT_EQ(scalar, graph.OperationByName("scalar")); | |||||
var neg = c_test_util.Neg(scalar, graph, s); | |||||
EXPECT_EQ(neg, graph.OperationByName("neg")); | |||||
c_test_util.Placeholder(graph, s); | |||||
ASSERT_TRUE(graph.OperationByName("feed") != null); | |||||
var oper = c_test_util.ScalarConst(3, graph, s); | |||||
ASSERT_TRUE(graph.OperationByName("scalar") != null); | |||||
c_test_util.Neg(oper, graph, s); | |||||
ASSERT_TRUE(graph.OperationByName("neg") != null); | |||||
// Export to a GraphDef. | // Export to a GraphDef. | ||||
var graph_def = graph.ToGraphDef(s); | var graph_def = graph.ToGraphDef(s); | ||||
@@ -389,9 +389,9 @@ namespace TensorFlowNET.UnitTest | |||||
var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); | var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); | ||||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
scalar = graph.OperationByName("scalar"); | |||||
feed = graph.OperationByName("feed"); | |||||
neg = graph.OperationByName("neg"); | |||||
var scalar = graph.OperationByName("scalar"); | |||||
var feed = graph.OperationByName("feed"); | |||||
var neg = graph.OperationByName("neg"); | |||||
ASSERT_TRUE(scalar != IntPtr.Zero); | ASSERT_TRUE(scalar != IntPtr.Zero); | ||||
ASSERT_TRUE(feed != IntPtr.Zero); | ASSERT_TRUE(feed != IntPtr.Zero); | ||||
ASSERT_TRUE(neg != IntPtr.Zero); | ASSERT_TRUE(neg != IntPtr.Zero); | ||||
@@ -57,9 +57,9 @@ namespace TensorFlowNET.UnitTest | |||||
EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | ||||
EXPECT_EQ(tensor.rank, nd.ndim); | EXPECT_EQ(tensor.rank, nd.ndim); | ||||
EXPECT_EQ(tensor.shape[0], nd.shape[0]); | |||||
EXPECT_EQ(tensor.shape[1], nd.shape[1]); | |||||
EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||||
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | |||||
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | |||||
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | ||||
} | } | ||||
@@ -118,8 +118,8 @@ namespace TensorFlowNET.UnitTest | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
EXPECT_EQ(2, num_dims); | EXPECT_EQ(2, num_dims); | ||||
EXPECT_EQ(2, returned_dims[0]); | |||||
EXPECT_EQ(3, returned_dims[1]); | |||||
EXPECT_EQ(2, (int)returned_dims[0]); | |||||
EXPECT_EQ(3, (int)returned_dims[1]); | |||||
// Try to set 'unknown' with same rank on the shape and see that | // Try to set 'unknown' with same rank on the shape and see that | ||||
// it doesn't change. | // it doesn't change. | ||||
@@ -130,8 +130,8 @@ namespace TensorFlowNET.UnitTest | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
EXPECT_EQ(2, num_dims); | EXPECT_EQ(2, num_dims); | ||||
EXPECT_EQ(2, returned_dims[0]); | |||||
EXPECT_EQ(3, returned_dims[1]); | |||||
EXPECT_EQ(2, (int)returned_dims[0]); | |||||
EXPECT_EQ(3, (int)returned_dims[1]); | |||||
// Try to fetch a shape with the wrong num_dims | // Try to fetch a shape with the wrong num_dims | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | ||||
@@ -13,13 +13,6 @@ namespace TensorFlowNET.UnitTest | |||||
public static class c_test_util | public static class c_test_util | ||||
{ | { | ||||
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | ||||
{ | |||||
Operation op = null; | |||||
AddOpHelper(l, r, graph, s, name, ref op, true); | |||||
return op; | |||||
} | |||||
public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check) | |||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | var desc = c_api.TF_NewOperation(graph, "AddN", name); | ||||
@@ -31,8 +24,10 @@ namespace TensorFlowNET.UnitTest | |||||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | c_api.TF_AddInputList(desc, inputs, inputs.Length); | ||||
op = c_api.TF_FinishOperation(desc, s); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | s.Check(); | ||||
return op; | |||||
} | } | ||||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
@@ -196,39 +191,29 @@ namespace TensorFlowNET.UnitTest | |||||
return op; | return op; | ||||
} | } | ||||
public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | |||||
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | |||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | ||||
c_api.TF_SetAttrType(desc, "dtype", dtype); | c_api.TF_SetAttrType(desc, "dtype", dtype); | ||||
if(dims != null) | |||||
if (dims != null) | |||||
{ | { | ||||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | ||||
} | } | ||||
op = c_api.TF_FinishOperation(desc, s); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | s.Check(); | ||||
} | |||||
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | |||||
{ | |||||
Operation op = null; | |||||
PlaceholderHelper(graph, s, name, dtype, dims, ref op); | |||||
return op; | return op; | ||||
} | } | ||||
public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op) | |||||
public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "Const", name); | var desc = c_api.TF_NewOperation(graph, "Const", name); | ||||
c_api.TF_SetAttrTensor(desc, "value", t, s); | c_api.TF_SetAttrTensor(desc, "value", t, s); | ||||
s.Check(); | s.Check(); | ||||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | c_api.TF_SetAttrType(desc, "dtype", t.dtype); | ||||
op = c_api.TF_FinishOperation(desc, s); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | s.Check(); | ||||
} | |||||
public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||||
{ | |||||
Operation op = null; | |||||
ConstHelper(t, graph, s, name, ref op); | |||||
return op; | return op; | ||||
} | } | ||||
@@ -247,7 +232,7 @@ namespace TensorFlowNET.UnitTest | |||||
(IntPtr values, IntPtr len, ref bool closure) => | (IntPtr values, IntPtr len, ref bool closure) => | ||||
{ | { | ||||
// Free the original buffer and set flag | // Free the original buffer and set flag | ||||
Marshal.FreeHGlobal(dotHandle); | |||||
// Marshal.FreeHGlobal(dotHandle); | |||||
}, ref deallocator_called); | }, ref deallocator_called); | ||||
} | } | ||||
} | } | ||||