@@ -15,8 +15,8 @@ namespace Tensorflow | |||
public partial class Graph : IDisposable | |||
{ | |||
private IntPtr _handle; | |||
private Dictionary<int, Operation> _nodes_by_id; | |||
private Dictionary<string, Operation> _nodes_by_name; | |||
private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||
private Dictionary<string, int> _names_in_use; | |||
public int _version; | |||
private int _next_id_counter; | |||
@@ -35,13 +35,13 @@ namespace Tensorflow | |||
{ | |||
_handle = c_api.TF_NewGraph(); | |||
Status = new Status(); | |||
_nodes_by_id = new Dictionary<int, Operation>(); | |||
_nodes_by_name = new Dictionary<string, Operation>(); | |||
_nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||
_names_in_use = new Dictionary<string, int>(); | |||
_graph_key = $"grap-key-{ops.uid()}/"; | |||
} | |||
public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||
{ | |||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||
} | |||
@@ -54,7 +54,7 @@ namespace Tensorflow | |||
return null; | |||
} | |||
private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||
{ | |||
string types_str = ""; | |||
@@ -294,7 +294,7 @@ namespace Tensorflow | |||
return c_api.TF_GraphOperationByName(_handle, operName); | |||
} | |||
public Operation[] get_operations() | |||
public ITensorOrOperation[] get_operations() | |||
{ | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
} | |||
@@ -36,12 +36,12 @@ namespace Tensorflow | |||
} | |||
public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null) | |||
public virtual NDArray run(object fetches, FeedItem[] feed_dict = null) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) | |||
private NDArray _run(object fetches, FeedItem[] feed_dict = null) | |||
{ | |||
var feed_dict_tensor = new Dictionary<object, object>(); | |||
@@ -49,7 +49,7 @@ namespace Tensorflow | |||
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | |||
// Create a fetch handler to take care of the structure of fetches. | |||
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor); | |||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
// Run request and get response. | |||
// We need to keep the returned movers alive for the following _do_run(). | |||
@@ -8,20 +8,36 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Fetch mapper for singleton tensors and ops. | |||
/// </summary> | |||
public class _ElementFetchMapper<T> : _FetchMapper<T> | |||
public class _ElementFetchMapper : _FetchMapper | |||
{ | |||
private List<object> _unique_fetches = new List<object>(); | |||
private Func<List<object>, object> _contraction_fn; | |||
public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn) | |||
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | |||
{ | |||
var g = ops.get_default_graph(); | |||
ITensorOrOperation el = null; | |||
foreach(var fetch in fetches) | |||
{ | |||
var g = ops.get_default_graph(); | |||
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||
_unique_fetches.Add(el); | |||
switch(fetch) | |||
{ | |||
case Tensor tensor: | |||
el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||
break; | |||
case Operation op: | |||
el = g.as_graph_element(op, allow_tensor: true, allow_operation: true); | |||
break; | |||
case String str: | |||
// Looks like a Tensor name and can be a Tensor. | |||
el = g._nodes_by_name[str]; | |||
break; | |||
default: | |||
throw new NotImplementedException("_ElementFetchMapper"); | |||
} | |||
} | |||
_unique_fetches.Add(el); | |||
_contraction_fn = contraction_fn; | |||
} | |||
@@ -8,24 +8,24 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Handler for structured fetches. | |||
/// </summary> | |||
public class _FetchHandler<T> | |||
public class _FetchHandler | |||
{ | |||
private _ElementFetchMapper<T> _fetch_mapper; | |||
private _ElementFetchMapper _fetch_mapper; | |||
private List<Tensor> _fetches = new List<Tensor>(); | |||
private List<bool> _ops = new List<bool>(); | |||
private List<Tensor> _final_fetches = new List<Tensor>(); | |||
private List<T> _targets = new List<T>(); | |||
private List<object> _targets = new List<object>(); | |||
public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||
public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||
{ | |||
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | |||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
{ | |||
switch (fetch) | |||
{ | |||
case Operation val: | |||
_assert_fetchable(graph, val); | |||
_targets.Add((T)(object)val); | |||
_targets.Add(val); | |||
_ops.Add(true); | |||
break; | |||
case Tensor val: | |||
@@ -82,7 +82,7 @@ namespace Tensorflow | |||
return _final_fetches; | |||
} | |||
public List<T> targets() | |||
public List<object> targets() | |||
{ | |||
return _targets; | |||
} | |||
@@ -4,13 +4,13 @@ using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class _FetchMapper<T> | |||
public class _FetchMapper | |||
{ | |||
public _ElementFetchMapper<T> for_fetch(T fetch) | |||
public _ElementFetchMapper for_fetch(object fetch) | |||
{ | |||
var fetches = new List<T> { fetch }; | |||
var fetches = new object[] { fetch }; | |||
return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) => | |||
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => | |||
{ | |||
return fetched_vals[0]; | |||
}); | |||
@@ -47,6 +47,11 @@ namespace Tensorflow | |||
return g; | |||
} | |||
public static void ResetGraph() | |||
{ | |||
g = new Graph(); | |||
} | |||
public static Session Session() | |||
{ | |||
defaultSession = new Session(); | |||
@@ -12,6 +12,8 @@ namespace TensorFlowNET.UnitTest | |||
[TestMethod] | |||
public void Gradients() | |||
{ | |||
tf.ResetGraph(); | |||
var a = tf.constant(0.0); | |||
var b = 2.0 * a; | |||
Assert.AreEqual(b.name, "mul:0"); | |||