From 2300e619aaab8bd7b7203d1004ca81b922243d14 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 21 Dec 2018 07:28:06 -0600 Subject: [PATCH] _FetchHandler --- src/TensorFlowNET.Core/Graph.cs | 15 +++++++++++ .../Session/_ElementFetchMapper.cs | 5 ++++ .../Session/_FetchHandler.cs | 25 +++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index 6a447463..96131b8a 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -22,6 +22,7 @@ namespace Tensorflow private Dictionary _names_in_use; public int _version; private int _next_id_counter; + private List _unfetchable_ops = new List(); public Graph(IntPtr graph) { @@ -111,6 +112,20 @@ namespace Tensorflow return ++_next_id_counter; } + public bool is_fetchable(T tensor_or_op) + { + if (tensor_or_op is Tensor) + { + return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; + } + else if (tensor_or_op is Operation) + { + return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); + } + + return false; + } + public string unique_name(string name) { var name_key = name.ToLower(); diff --git a/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs index 310b1cfc..94565316 100644 --- a/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs @@ -20,5 +20,10 @@ namespace Tensorflow _unique_fetches.Add(fetch); } } + + public List unique_fetches() + { + return _unique_fetches; + } } } diff --git a/src/TensorFlowNET.Core/Session/_FetchHandler.cs b/src/TensorFlowNET.Core/Session/_FetchHandler.cs index 35b5e4b6..eb14ef69 100644 --- a/src/TensorFlowNET.Core/Session/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Session/_FetchHandler.cs @@ -10,10 +10,35 @@ namespace Tensorflow public class _FetchHandler { private _ElementFetchMapper _fetch_mapper; + private List _fetches = new List(); + private List _ops = new List(); + private List _final_fetches = new List(); public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) { _fetch_mapper = new _FetchMapper().for_fetch(fetches); + foreach(var fetch in _fetch_mapper.unique_fetches()) + { + switch (fetch) + { + case Tensor val: + _assert_fetchable(graph, val.op); + _fetches.Add(fetch); + _ops.Add(false); + break; + } + + } + + _final_fetches = _fetches; + } + + private void _assert_fetchable(Graph graph, Operation op) + { + if (!graph.is_fetchable(op)) + { + throw new Exception($"Operation {op.name} has been marked as not fetchable."); + } } } }