diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index bc559543..6a447463 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -31,6 +31,50 @@ namespace Tensorflow _names_in_use = new Dictionary(); } + public T as_graph_element(T obj, bool allow_tensor = true, bool allow_operation = true) + { + return _as_graph_element_locked(obj, allow_tensor, allow_operation); + } + + private Func _as_graph_element(object obj) + { + return null; + } + + private T _as_graph_element_locked(T obj, bool allow_tensor = true, bool allow_operation = true) + { + string types_str = ""; + + if (allow_tensor && allow_operation) + { + types_str = "Tensor or Operation"; + } + else if (allow_tensor) + { + types_str = "Tensor"; + } + else if (allow_operation) + { + types_str = "Operation"; + } + + var temp_obj = _as_graph_element(obj); + + if(obj is Tensor && allow_tensor) + { + if ((obj as Tensor).graph.Equals(this)) + { + return obj; + } + else + { + throw new Exception($"Tensor {obj} is not an element of this graph."); + } + } + + throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); + } + public unsafe Operation create_op(string op_type, List inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = "", Dictionary attrs = null, OpDef op_def = null) diff --git a/src/TensorFlowNET.Core/Operation.cs b/src/TensorFlowNET.Core/Operation.cs index d2bbca5c..16ac6487 100644 --- a/src/TensorFlowNET.Core/Operation.cs +++ b/src/TensorFlowNET.Core/Operation.cs @@ -8,6 +8,7 @@ namespace Tensorflow public class Operation { private Graph _graph; + public Graph graph => _graph; public IntPtr _c_op; public int _id => _id_value; private int _id_value; diff --git a/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs new file mode 100644 index 00000000..310b1cfc --- /dev/null +++ b/src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + /// + /// Fetch mapper for singleton tensors and ops. + /// + public class _ElementFetchMapper : _FetchMapper + { + private List _unique_fetches = new List(); + private Action _contraction_fn; + + public _ElementFetchMapper(List fetches, Action contraction_fn) + { + foreach(var tensor in fetches) + { + var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); + _unique_fetches.Add(fetch); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Session/_FetchHandler.cs b/src/TensorFlowNET.Core/Session/_FetchHandler.cs index c4d4d5ee..35b5e4b6 100644 --- a/src/TensorFlowNET.Core/Session/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Session/_FetchHandler.cs @@ -9,9 +9,11 @@ namespace Tensorflow /// public class _FetchHandler { - public _FetchHandler(Graph graph, Tensor fetches) - { + private _ElementFetchMapper _fetch_mapper; + public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) + { + _fetch_mapper = new _FetchMapper().for_fetch(fetches); } } } diff --git a/src/TensorFlowNET.Core/Session/_FetchMapper.cs b/src/TensorFlowNET.Core/Session/_FetchMapper.cs new file mode 100644 index 00000000..763b67a0 --- /dev/null +++ b/src/TensorFlowNET.Core/Session/_FetchMapper.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class _FetchMapper + { + public _ElementFetchMapper for_fetch(Tensor fetch) + { + var fetches = new List { fetch }; + + return new _ElementFetchMapper(fetches, null); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensor.cs b/src/TensorFlowNET.Core/Tensor.cs index e9737d5c..bdaf8a41 100644 --- a/src/TensorFlowNET.Core/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensor.cs @@ -13,6 +13,8 @@ namespace Tensorflow private DataType _dtype; public DataType dtype => _dtype; + public Graph graph => _op.graph; + public string name; public Tensor(Operation op, int value_index, DataType dtype)