From e237ba6bff257818b3bed58b2da91fa6da8e8ce6 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 10 Feb 2019 10:29:11 -0600 Subject: [PATCH] #172 --- src/TensorFlowNET.Core/Graphs/Graph.cs | 14 +++++----- .../Sessions/BaseSession.cs | 6 ++--- .../Sessions/_ElementFetchMapper.cs | 26 +++++++++++++++---- .../Sessions/_FetchHandler.cs | 14 +++++----- .../Sessions/_FetchMapper.cs | 8 +++--- src/TensorFlowNET.Core/tf.cs | 5 ++++ test/TensorFlowNET.UnitTest/GradientTest.cs | 2 ++ 7 files changed, 49 insertions(+), 26 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index c1d5fa03..14ec0f00 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -15,8 +15,8 @@ namespace Tensorflow public partial class Graph : IDisposable { private IntPtr _handle; - private Dictionary _nodes_by_id; - private Dictionary _nodes_by_name; + private Dictionary _nodes_by_id; + public Dictionary _nodes_by_name; private Dictionary _names_in_use; public int _version; private int _next_id_counter; @@ -35,13 +35,13 @@ namespace Tensorflow { _handle = c_api.TF_NewGraph(); Status = new Status(); - _nodes_by_id = new Dictionary(); - _nodes_by_name = new Dictionary(); + _nodes_by_id = new Dictionary(); + _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); _graph_key = $"grap-key-{ops.uid()}/"; } - public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) + public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) { return _as_graph_element_locked(obj, allow_tensor, allow_operation); } @@ -54,7 +54,7 @@ namespace Tensorflow return null; } - private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) + private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) { string types_str = ""; @@ -294,7 +294,7 @@ namespace Tensorflow return c_api.TF_GraphOperationByName(_handle, operName); } - public Operation[] get_operations() + public ITensorOrOperation[] get_operations() { return _nodes_by_name.Values.Select(x => x).ToArray(); } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 54ba3759..21949aaa 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -36,12 +36,12 @@ namespace Tensorflow } - public virtual NDArray run(T fetches, FeedItem[] feed_dict = null) + public virtual NDArray run(object fetches, FeedItem[] feed_dict = null) { return _run(fetches, feed_dict); } - private NDArray _run(T fetches, FeedItem[] feed_dict = null) + private NDArray _run(object fetches, FeedItem[] feed_dict = null) { var feed_dict_tensor = new Dictionary(); @@ -49,7 +49,7 @@ namespace Tensorflow feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); // Create a fetch handler to take care of the structure of fetches. - var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); + var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); // Run request and get response. // We need to keep the returned movers alive for the following _do_run(). diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index 7960d200..3221285f 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -8,20 +8,36 @@ namespace Tensorflow /// /// Fetch mapper for singleton tensors and ops. /// - public class _ElementFetchMapper : _FetchMapper + public class _ElementFetchMapper : _FetchMapper { private List _unique_fetches = new List(); private Func, object> _contraction_fn; - public _ElementFetchMapper(List fetches, Func, object> contraction_fn) + public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) { + var g = ops.get_default_graph(); + ITensorOrOperation el = null; + foreach(var fetch in fetches) { - var g = ops.get_default_graph(); - var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); - _unique_fetches.Add(el); + switch(fetch) + { + case Tensor tensor: + el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true); + break; + case Operation op: + el = g.as_graph_element(op, allow_tensor: true, allow_operation: true); + break; + case String str: + // Looks like a Tensor name and can be a Tensor. + el = g._nodes_by_name[str]; + break; + default: + throw new NotImplementedException("_ElementFetchMapper"); + } } + _unique_fetches.Add(el); _contraction_fn = contraction_fn; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index f4e699cb..e45e3823 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -8,24 +8,24 @@ namespace Tensorflow /// /// Handler for structured fetches. /// - public class _FetchHandler + public class _FetchHandler { - private _ElementFetchMapper _fetch_mapper; + private _ElementFetchMapper _fetch_mapper; private List _fetches = new List(); private List _ops = new List(); private List _final_fetches = new List(); - private List _targets = new List(); + private List _targets = new List(); - public _FetchHandler(Graph graph, T fetches, Dictionary feeds = null, Action feed_handles = null) + public _FetchHandler(Graph graph, object fetches, Dictionary feeds = null, Action feed_handles = null) { - _fetch_mapper = new _FetchMapper().for_fetch(fetches); + _fetch_mapper = new _FetchMapper().for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) { switch (fetch) { case Operation val: _assert_fetchable(graph, val); - _targets.Add((T)(object)val); + _targets.Add(val); _ops.Add(true); break; case Tensor val: @@ -82,7 +82,7 @@ namespace Tensorflow return _final_fetches; } - public List targets() + public List targets() { return _targets; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index b5eff215..fc9d6b43 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -4,13 +4,13 @@ using System.Text; namespace Tensorflow { - public class _FetchMapper + public class _FetchMapper { - public _ElementFetchMapper for_fetch(T fetch) + public _ElementFetchMapper for_fetch(object fetch) { - var fetches = new List { fetch }; + var fetches = new object[] { fetch }; - return new _ElementFetchMapper(fetches, (List fetched_vals) => + return new _ElementFetchMapper(fetches, (List fetched_vals) => { return fetched_vals[0]; }); diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index ed5b428c..a3e19fb1 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -47,6 +47,11 @@ namespace Tensorflow return g; } + public static void ResetGraph() + { + g = new Graph(); + } + public static Session Session() { defaultSession = new Session(); diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs index bc887764..f633020f 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest.cs @@ -12,6 +12,8 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void Gradients() { + tf.ResetGraph(); + var a = tf.constant(0.0); var b = 2.0 * a; Assert.AreEqual(b.name, "mul:0");