@@ -31,6 +31,50 @@ namespace Tensorflow | |||||
_names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
} | } | ||||
public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | |||||
{ | |||||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||||
} | |||||
private Func<object> _as_graph_element(object obj) | |||||
{ | |||||
return null; | |||||
} | |||||
private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | |||||
{ | |||||
string types_str = ""; | |||||
if (allow_tensor && allow_operation) | |||||
{ | |||||
types_str = "Tensor or Operation"; | |||||
} | |||||
else if (allow_tensor) | |||||
{ | |||||
types_str = "Tensor"; | |||||
} | |||||
else if (allow_operation) | |||||
{ | |||||
types_str = "Operation"; | |||||
} | |||||
var temp_obj = _as_graph_element(obj); | |||||
if(obj is Tensor && allow_tensor) | |||||
{ | |||||
if ((obj as Tensor).graph.Equals(this)) | |||||
{ | |||||
return obj; | |||||
} | |||||
else | |||||
{ | |||||
throw new Exception($"Tensor {obj} is not an element of this graph."); | |||||
} | |||||
} | |||||
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | |||||
} | |||||
public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | ||||
TF_DataType[] input_types = null, string name = "", | TF_DataType[] input_types = null, string name = "", | ||||
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
@@ -8,6 +8,7 @@ namespace Tensorflow | |||||
public class Operation | public class Operation | ||||
{ | { | ||||
private Graph _graph; | private Graph _graph; | ||||
public Graph graph => _graph; | |||||
public IntPtr _c_op; | public IntPtr _c_op; | ||||
public int _id => _id_value; | public int _id => _id_value; | ||||
private int _id_value; | private int _id_value; | ||||
@@ -0,0 +1,24 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// Fetch mapper for singleton tensors and ops. | |||||
/// </summary> | |||||
public class _ElementFetchMapper : _FetchMapper | |||||
{ | |||||
private List<Object> _unique_fetches = new List<object>(); | |||||
private Action _contraction_fn; | |||||
public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn) | |||||
{ | |||||
foreach(var tensor in fetches) | |||||
{ | |||||
var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||||
_unique_fetches.Add(fetch); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -9,9 +9,11 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public class _FetchHandler | public class _FetchHandler | ||||
{ | { | ||||
public _FetchHandler(Graph graph, Tensor fetches) | |||||
{ | |||||
private _ElementFetchMapper _fetch_mapper; | |||||
public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | |||||
{ | |||||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class _FetchMapper | |||||
{ | |||||
public _ElementFetchMapper for_fetch(Tensor fetch) | |||||
{ | |||||
var fetches = new List<Tensor> { fetch }; | |||||
return new _ElementFetchMapper(fetches, null); | |||||
} | |||||
} | |||||
} |
@@ -13,6 +13,8 @@ namespace Tensorflow | |||||
private DataType _dtype; | private DataType _dtype; | ||||
public DataType dtype => _dtype; | public DataType dtype => _dtype; | ||||
public Graph graph => _op.graph; | |||||
public string name; | public string name; | ||||
public Tensor(Operation op, int value_index, DataType dtype) | public Tensor(Operation op, int value_index, DataType dtype) | ||||