From fca036836102cb084b344d89607bd16f8fd9bcc4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 22 Dec 2018 15:53:06 -0600 Subject: [PATCH] addInConstant unit test successfully --- src/TensorFlowNET.Core/Session/BaseSession.cs | 68 ++++++++++++++++--- .../Session/_ElementFetchMapper.cs | 5 ++ .../Session/_FetchHandler.cs | 16 +++++ src/TensorFlowNET.Core/Tensor.cs | 9 +++ src/TensorFlowNET.Core/c_api.cs | 3 + test/TensorFlowNET.UnitTest/OperationsTest.cs | 24 ++++--- 6 files changed, 109 insertions(+), 16 deletions(-) diff --git a/src/TensorFlowNET.Core/Session/BaseSession.cs b/src/TensorFlowNET.Core/Session/BaseSession.cs index fb61eb0d..b819f662 100644 --- a/src/TensorFlowNET.Core/Session/BaseSession.cs +++ b/src/TensorFlowNET.Core/Session/BaseSession.cs @@ -1,6 +1,7 @@ using NumSharp.Core; using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -38,12 +39,14 @@ namespace Tensorflow } - public virtual byte[] run(Tensor fetches, Dictionary feed_dict = null) + public virtual object run(Tensor fetches, Dictionary feed_dict = null) { - return _run(fetches, feed_dict); + var result = _run(fetches, feed_dict); + + return result; } - private unsafe byte[] _run(Tensor fetches, Dictionary feed_dict = null) + private unsafe object _run(Tensor fetches, Dictionary feed_dict = null) { var feed_dict_tensor = new Dictionary(); @@ -66,22 +69,71 @@ namespace Tensorflow // Create a fetch handler to take care of the structure of fetches. var fetch_handler = new _FetchHandler(_graph, fetches); + // Run request and get response. + // We need to keep the returned movers alive for the following _do_run(). + // These movers are no longer needed when _do_run() completes, and + // are deleted when `movers` goes out of scope when this _run() ends. + var _ = _update_with_movers(); + var final_fetches = fetch_handler.fetches(); + var final_targets = fetch_handler.targets(); + + // We only want to really perform the run if fetches or targets are provided, + // or if the call is a partial run that specifies feeds. + var results = _do_run(final_fetches); + + return fetch_handler.build_results(null, results); + } + + private object[] _do_run(List fetch_list) + { + var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray(); + + return _call_tf_sessionrun(fetches); + } + + private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list) + { + // Ensure any changes to the graph are reflected in the runtime. + _extend_graph(); + var status = new Status(); + var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); + c_api.TF_SessionRun(_session, run_options: IntPtr.Zero, inputs: new TF_Output[] { }, input_values: new IntPtr[] { }, ninputs: 0, - outputs: new TF_Output[] { new TF_Output() }, - output_values: new IntPtr[] { }, - noutputs: 1, + outputs: fetch_list, + output_values: output_values, + noutputs: fetch_list.Length, target_opers: new IntPtr[] { }, - ntargets: 1, + ntargets: 0, run_metadata: IntPtr.Zero, status: status.Handle); - return null; + var result = output_values.Select(x => new Tensor(x).buffer).Select(x => + { + return (object)*(float*)x; + }).ToArray(); + + return result; + } + + /// + /// If a tensor handle that is fed to a device incompatible placeholder, + /// we move the tensor to the right device, generate a new tensor handle, + /// and update feed_dict to use the new handle. + /// + private List _update_with_movers() + { + return new List { }; + } + + private void _extend_graph() + { + } } } diff --git a/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs index 94565316..908f516c 100644 --- a/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs @@ -21,6 +21,11 @@ namespace Tensorflow } } + public object build_results(object[] values) + { + return values[0]; + } + public List unique_fetches() { return _unique_fetches; diff --git a/src/TensorFlowNET.Core/Session/_FetchHandler.cs b/src/TensorFlowNET.Core/Session/_FetchHandler.cs index eb14ef69..0ec355d5 100644 --- a/src/TensorFlowNET.Core/Session/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Session/_FetchHandler.cs @@ -13,6 +13,7 @@ namespace Tensorflow private List _fetches = new List(); private List _ops = new List(); private List _final_fetches = new List(); + private List _targets = new List(); public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) { @@ -33,6 +34,11 @@ namespace Tensorflow _final_fetches = _fetches; } + public object build_results(Session session, object[] results) + { + return _fetch_mapper.build_results(results); + } + private void _assert_fetchable(Graph graph, Operation op) { if (!graph.is_fetchable(op)) @@ -40,5 +46,15 @@ namespace Tensorflow throw new Exception($"Operation {op.name} has been marked as not fetchable."); } } + + public List fetches() + { + return _final_fetches; + } + + public List targets() + { + return _targets; + } } } diff --git a/src/TensorFlowNET.Core/Tensor.cs b/src/TensorFlowNET.Core/Tensor.cs index bdaf8a41..9465d1bd 100644 --- a/src/TensorFlowNET.Core/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensor.cs @@ -17,6 +17,15 @@ namespace Tensorflow public string name; + private readonly IntPtr _handle; + public IntPtr handle => _handle; + public IntPtr buffer => c_api.TF_TensorData(_handle); + + public Tensor(IntPtr handle) + { + _handle = handle; + } + public Tensor(Operation op, int value_index, DataType dtype) { _op = op; diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index abd50bf3..6c9acba4 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -77,6 +77,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); + [DllImport(TensorFlowLibName)] + public static extern unsafe IntPtr TF_TensorData(TF_Tensor tensor); + [DllImport(TensorFlowLibName)] public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index f781cb5d..23b17070 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -12,13 +12,7 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void constant() { - var a = tf.constant(4.0f); - var b = tf.constant(5.0f); - var c = tf.add(a, b); - using (var sess = tf.Session()) - { - var o = sess.run(c); - } + var x = tf.constant(4.0f); } [TestMethod] @@ -28,7 +22,7 @@ namespace TensorFlowNET.UnitTest } [TestMethod] - public void add() + public void addInPlaceholder() { var a = tf.placeholder(tf.float32); var b = tf.placeholder(tf.float32); @@ -43,5 +37,19 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, feed_dict); } } + + [TestMethod] + public void addInConstant() + { + var a = tf.constant(4.0f); + var b = tf.constant(5.0f); + var c = tf.add(a, b); + + using (var sess = tf.Session()) + { + var o = sess.run(c); + Assert.AreEqual(o, 9.0f); + } + } } }