@@ -15,8 +15,8 @@ namespace Tensorflow | |||||
public partial class Graph : IDisposable | public partial class Graph : IDisposable | ||||
{ | { | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
private Dictionary<int, Operation> _nodes_by_id; | |||||
private Dictionary<string, Operation> _nodes_by_name; | |||||
private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||||
public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||||
private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
public int _version; | public int _version; | ||||
private int _next_id_counter; | private int _next_id_counter; | ||||
@@ -35,13 +35,13 @@ namespace Tensorflow | |||||
{ | { | ||||
_handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
Status = new Status(); | Status = new Status(); | ||||
_nodes_by_id = new Dictionary<int, Operation>(); | |||||
_nodes_by_name = new Dictionary<string, Operation>(); | |||||
_nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||||
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||||
_names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
_graph_key = $"grap-key-{ops.uid()}/"; | _graph_key = $"grap-key-{ops.uid()}/"; | ||||
} | } | ||||
public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||||
public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||||
{ | { | ||||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
} | } | ||||
@@ -54,7 +54,7 @@ namespace Tensorflow | |||||
return null; | return null; | ||||
} | } | ||||
private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | |||||
private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||||
{ | { | ||||
string types_str = ""; | string types_str = ""; | ||||
@@ -294,7 +294,7 @@ namespace Tensorflow | |||||
return c_api.TF_GraphOperationByName(_handle, operName); | return c_api.TF_GraphOperationByName(_handle, operName); | ||||
} | } | ||||
public Operation[] get_operations() | |||||
public ITensorOrOperation[] get_operations() | |||||
{ | { | ||||
return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
} | } | ||||
@@ -36,12 +36,12 @@ namespace Tensorflow | |||||
} | } | ||||
public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null) | |||||
public virtual NDArray run(object fetches, FeedItem[] feed_dict = null) | |||||
{ | { | ||||
return _run(fetches, feed_dict); | return _run(fetches, feed_dict); | ||||
} | } | ||||
private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) | |||||
private NDArray _run(object fetches, FeedItem[] feed_dict = null) | |||||
{ | { | ||||
var feed_dict_tensor = new Dictionary<object, object>(); | var feed_dict_tensor = new Dictionary<object, object>(); | ||||
@@ -49,7 +49,7 @@ namespace Tensorflow | |||||
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | ||||
// Create a fetch handler to take care of the structure of fetches. | // Create a fetch handler to take care of the structure of fetches. | ||||
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor); | |||||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
// Run request and get response. | // Run request and get response. | ||||
// We need to keep the returned movers alive for the following _do_run(). | // We need to keep the returned movers alive for the following _do_run(). | ||||
@@ -8,20 +8,36 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Fetch mapper for singleton tensors and ops. | /// Fetch mapper for singleton tensors and ops. | ||||
/// </summary> | /// </summary> | ||||
public class _ElementFetchMapper<T> : _FetchMapper<T> | |||||
public class _ElementFetchMapper : _FetchMapper | |||||
{ | { | ||||
private List<object> _unique_fetches = new List<object>(); | private List<object> _unique_fetches = new List<object>(); | ||||
private Func<List<object>, object> _contraction_fn; | private Func<List<object>, object> _contraction_fn; | ||||
public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn) | |||||
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | |||||
{ | { | ||||
var g = ops.get_default_graph(); | |||||
ITensorOrOperation el = null; | |||||
foreach(var fetch in fetches) | foreach(var fetch in fetches) | ||||
{ | { | ||||
var g = ops.get_default_graph(); | |||||
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||||
_unique_fetches.Add(el); | |||||
switch(fetch) | |||||
{ | |||||
case Tensor tensor: | |||||
el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||||
break; | |||||
case Operation op: | |||||
el = g.as_graph_element(op, allow_tensor: true, allow_operation: true); | |||||
break; | |||||
case String str: | |||||
// Looks like a Tensor name and can be a Tensor. | |||||
el = g._nodes_by_name[str]; | |||||
break; | |||||
default: | |||||
throw new NotImplementedException("_ElementFetchMapper"); | |||||
} | |||||
} | } | ||||
_unique_fetches.Add(el); | |||||
_contraction_fn = contraction_fn; | _contraction_fn = contraction_fn; | ||||
} | } | ||||
@@ -8,24 +8,24 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Handler for structured fetches. | /// Handler for structured fetches. | ||||
/// </summary> | /// </summary> | ||||
public class _FetchHandler<T> | |||||
public class _FetchHandler | |||||
{ | { | ||||
private _ElementFetchMapper<T> _fetch_mapper; | |||||
private _ElementFetchMapper _fetch_mapper; | |||||
private List<Tensor> _fetches = new List<Tensor>(); | private List<Tensor> _fetches = new List<Tensor>(); | ||||
private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
private List<T> _targets = new List<T>(); | |||||
private List<object> _targets = new List<object>(); | |||||
public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||||
public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||||
{ | { | ||||
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | |||||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||||
foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
{ | { | ||||
switch (fetch) | switch (fetch) | ||||
{ | { | ||||
case Operation val: | case Operation val: | ||||
_assert_fetchable(graph, val); | _assert_fetchable(graph, val); | ||||
_targets.Add((T)(object)val); | |||||
_targets.Add(val); | |||||
_ops.Add(true); | _ops.Add(true); | ||||
break; | break; | ||||
case Tensor val: | case Tensor val: | ||||
@@ -82,7 +82,7 @@ namespace Tensorflow | |||||
return _final_fetches; | return _final_fetches; | ||||
} | } | ||||
public List<T> targets() | |||||
public List<object> targets() | |||||
{ | { | ||||
return _targets; | return _targets; | ||||
} | } | ||||
@@ -4,13 +4,13 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class _FetchMapper<T> | |||||
public class _FetchMapper | |||||
{ | { | ||||
public _ElementFetchMapper<T> for_fetch(T fetch) | |||||
public _ElementFetchMapper for_fetch(object fetch) | |||||
{ | { | ||||
var fetches = new List<T> { fetch }; | |||||
var fetches = new object[] { fetch }; | |||||
return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) => | |||||
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => | |||||
{ | { | ||||
return fetched_vals[0]; | return fetched_vals[0]; | ||||
}); | }); | ||||
@@ -47,6 +47,11 @@ namespace Tensorflow | |||||
return g; | return g; | ||||
} | } | ||||
public static void ResetGraph() | |||||
{ | |||||
g = new Graph(); | |||||
} | |||||
public static Session Session() | public static Session Session() | ||||
{ | { | ||||
defaultSession = new Session(); | defaultSession = new Session(); | ||||
@@ -12,6 +12,8 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void Gradients() | public void Gradients() | ||||
{ | { | ||||
tf.ResetGraph(); | |||||
var a = tf.constant(0.0); | var a = tf.constant(0.0); | ||||
var b = 2.0 * a; | var b = 2.0 * a; | ||||
Assert.AreEqual(b.name, "mul:0"); | Assert.AreEqual(b.name, "mul:0"); | ||||