@@ -1,6 +1,7 @@ | |||||
using NumSharp.Core; | using NumSharp.Core; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -38,12 +39,14 @@ namespace Tensorflow | |||||
} | } | ||||
public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
{ | { | ||||
return _run(fetches, feed_dict); | |||||
var result = _run(fetches, feed_dict); | |||||
return result; | |||||
} | } | ||||
private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
{ | { | ||||
var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | ||||
@@ -66,22 +69,71 @@ namespace Tensorflow | |||||
// Create a fetch handler to take care of the structure of fetches. | // Create a fetch handler to take care of the structure of fetches. | ||||
var fetch_handler = new _FetchHandler(_graph, fetches); | var fetch_handler = new _FetchHandler(_graph, fetches); | ||||
// Run request and get response. | |||||
// We need to keep the returned movers alive for the following _do_run(). | |||||
// These movers are no longer needed when _do_run() completes, and | |||||
// are deleted when `movers` goes out of scope when this _run() ends. | |||||
var _ = _update_with_movers(); | |||||
var final_fetches = fetch_handler.fetches(); | |||||
var final_targets = fetch_handler.targets(); | |||||
// We only want to really perform the run if fetches or targets are provided, | |||||
// or if the call is a partial run that specifies feeds. | |||||
var results = _do_run(final_fetches); | |||||
return fetch_handler.build_results(null, results); | |||||
} | |||||
private object[] _do_run(List<object> fetch_list) | |||||
{ | |||||
var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray(); | |||||
return _call_tf_sessionrun(fetches); | |||||
} | |||||
private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list) | |||||
{ | |||||
// Ensure any changes to the graph are reflected in the runtime. | |||||
_extend_graph(); | |||||
var status = new Status(); | var status = new Status(); | ||||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
c_api.TF_SessionRun(_session, | c_api.TF_SessionRun(_session, | ||||
run_options: IntPtr.Zero, | run_options: IntPtr.Zero, | ||||
inputs: new TF_Output[] { }, | inputs: new TF_Output[] { }, | ||||
input_values: new IntPtr[] { }, | input_values: new IntPtr[] { }, | ||||
ninputs: 0, | ninputs: 0, | ||||
outputs: new TF_Output[] { new TF_Output() }, | |||||
output_values: new IntPtr[] { }, | |||||
noutputs: 1, | |||||
outputs: fetch_list, | |||||
output_values: output_values, | |||||
noutputs: fetch_list.Length, | |||||
target_opers: new IntPtr[] { }, | target_opers: new IntPtr[] { }, | ||||
ntargets: 1, | |||||
ntargets: 0, | |||||
run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
status: status.Handle); | status: status.Handle); | ||||
return null; | |||||
var result = output_values.Select(x => new Tensor(x).buffer).Select(x => | |||||
{ | |||||
return (object)*(float*)x; | |||||
}).ToArray(); | |||||
return result; | |||||
} | |||||
/// <summary> | |||||
/// If a tensor handle that is fed to a device incompatible placeholder, | |||||
/// we move the tensor to the right device, generate a new tensor handle, | |||||
/// and update feed_dict to use the new handle. | |||||
/// </summary> | |||||
private List<object> _update_with_movers() | |||||
{ | |||||
return new List<object> { }; | |||||
} | |||||
private void _extend_graph() | |||||
{ | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -21,6 +21,11 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public object build_results(object[] values) | |||||
{ | |||||
return values[0]; | |||||
} | |||||
public List<Object> unique_fetches() | public List<Object> unique_fetches() | ||||
{ | { | ||||
return _unique_fetches; | return _unique_fetches; | ||||
@@ -13,6 +13,7 @@ namespace Tensorflow | |||||
private List<object> _fetches = new List<object>(); | private List<object> _fetches = new List<object>(); | ||||
private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
private List<object> _final_fetches = new List<object>(); | private List<object> _final_fetches = new List<object>(); | ||||
private List<object> _targets = new List<object>(); | |||||
public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | ||||
{ | { | ||||
@@ -33,6 +34,11 @@ namespace Tensorflow | |||||
_final_fetches = _fetches; | _final_fetches = _fetches; | ||||
} | } | ||||
public object build_results(Session session, object[] results) | |||||
{ | |||||
return _fetch_mapper.build_results(results); | |||||
} | |||||
private void _assert_fetchable(Graph graph, Operation op) | private void _assert_fetchable(Graph graph, Operation op) | ||||
{ | { | ||||
if (!graph.is_fetchable(op)) | if (!graph.is_fetchable(op)) | ||||
@@ -40,5 +46,15 @@ namespace Tensorflow | |||||
throw new Exception($"Operation {op.name} has been marked as not fetchable."); | throw new Exception($"Operation {op.name} has been marked as not fetchable."); | ||||
} | } | ||||
} | } | ||||
public List<Object> fetches() | |||||
{ | |||||
return _final_fetches; | |||||
} | |||||
public List<Object> targets() | |||||
{ | |||||
return _targets; | |||||
} | |||||
} | } | ||||
} | } |
@@ -17,6 +17,15 @@ namespace Tensorflow | |||||
public string name; | public string name; | ||||
private readonly IntPtr _handle; | |||||
public IntPtr handle => _handle; | |||||
public IntPtr buffer => c_api.TF_TensorData(_handle); | |||||
public Tensor(IntPtr handle) | |||||
{ | |||||
_handle = handle; | |||||
} | |||||
public Tensor(Operation op, int value_index, DataType dtype) | public Tensor(Operation op, int value_index, DataType dtype) | ||||
{ | { | ||||
_op = op; | _op = op; | ||||
@@ -77,6 +77,9 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern unsafe IntPtr TF_TensorData(TF_Tensor tensor); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | ||||
@@ -12,13 +12,7 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void constant() | public void constant() | ||||
{ | { | ||||
var a = tf.constant(4.0f); | |||||
var b = tf.constant(5.0f); | |||||
var c = tf.add(a, b); | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
var o = sess.run(c); | |||||
} | |||||
var x = tf.constant(4.0f); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -28,7 +22,7 @@ namespace TensorFlowNET.UnitTest | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void add() | |||||
public void addInPlaceholder() | |||||
{ | { | ||||
var a = tf.placeholder(tf.float32); | var a = tf.placeholder(tf.float32); | ||||
var b = tf.placeholder(tf.float32); | var b = tf.placeholder(tf.float32); | ||||
@@ -43,5 +37,19 @@ namespace TensorFlowNET.UnitTest | |||||
var o = sess.run(c, feed_dict); | var o = sess.run(c, feed_dict); | ||||
} | } | ||||
} | } | ||||
[TestMethod] | |||||
public void addInConstant() | |||||
{ | |||||
var a = tf.constant(4.0f); | |||||
var b = tf.constant(5.0f); | |||||
var c = tf.add(a, b); | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
var o = sess.run(c); | |||||
Assert.AreEqual(o, 9.0f); | |||||
} | |||||
} | |||||
} | } | ||||
} | } |