@@ -77,9 +77,9 @@ namespace Tensorflow | |||
var temp_obj = _as_graph_element(obj); | |||
if(obj is Tensor && allow_tensor) | |||
if (obj is Tensor tensor && allow_tensor) | |||
{ | |||
if ((obj as Tensor).Graph.Equals(this)) | |||
if (tensor.Graph.Equals(this)) | |||
{ | |||
return obj; | |||
} | |||
@@ -88,6 +88,17 @@ namespace Tensorflow | |||
throw new Exception($"Tensor {obj} is not an element of this graph."); | |||
} | |||
} | |||
else if (obj is Operation op && allow_operation) | |||
{ | |||
if (op.Graph.Equals(this)) | |||
{ | |||
return obj; | |||
} | |||
else | |||
{ | |||
throw new Exception($"Operation {obj} is not an element of this graph."); | |||
} | |||
} | |||
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | |||
} | |||
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow | |||
@@ -10,14 +11,35 @@ namespace Tensorflow | |||
{ | |||
using(var namescope = new ops.name_scope<Operation>(name, "group_deps", inputs)) | |||
{ | |||
name = namescope; | |||
var ops_on_device = new Dictionary<string, Operation[]>(); | |||
// Sorts *inputs according to their devices. | |||
foreach (var inp in inputs) | |||
{ | |||
ops_on_device[inp.Device] = new Operation[] { inp }; | |||
} | |||
// 1-level tree. The root node is the returned NoOp node. | |||
if (ops_on_device.Count == 1) | |||
{ | |||
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); | |||
} | |||
return _GroupControlDeps("", name); | |||
// 2-level tree. The root node is the returned NoOp node. | |||
// deps contains 1 NoOp node for each device. | |||
return null; | |||
} | |||
} | |||
private static Operation _GroupControlDeps(string dev, string name = "") | |||
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") | |||
{ | |||
if (string.IsNullOrEmpty(dev)) | |||
{ | |||
return gen_control_flow_ops.no_op(name); | |||
} | |||
return null; | |||
} | |||
} | |||
@@ -0,0 +1,18 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class gen_control_flow_ops | |||
{ | |||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
public static Operation no_op(string name = "") | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("NoOp", name); | |||
return _op; | |||
} | |||
} | |||
} |
@@ -40,7 +40,12 @@ namespace Tensorflow | |||
return _run(fetches, feed_dict); | |||
} | |||
private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
public virtual NDArray run(Operation fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
private NDArray _run<T>(T fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
{ | |||
var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||
@@ -53,7 +58,7 @@ namespace Tensorflow | |||
} | |||
// Create a fetch handler to take care of the structure of fetches. | |||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor); | |||
// Run request and get response. | |||
// We need to keep the returned movers alive for the following _do_run(). | |||
@@ -65,20 +70,34 @@ namespace Tensorflow | |||
// We only want to really perform the run if fetches or targets are provided, | |||
// or if the call is a partial run that specifies feeds. | |||
var results = _do_run(final_fetches, feed_dict_tensor); | |||
var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); | |||
return fetch_handler.build_results(null, results); | |||
} | |||
private NDArray[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||
/// <summary> | |||
/// Runs a step based on the given fetches and feeds. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
/// <param name="fetch_list"></param> | |||
/// <param name="feed_dict"></param> | |||
/// <returns> | |||
/// A list of numpy ndarrays, corresponding to the elements of | |||
/// `fetch_list`. If the ith element of `fetch_list` contains the | |||
/// name of an operation, the first Tensor output of that operation | |||
/// will be returned for that element. | |||
/// </returns> | |||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||
{ | |||
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | |||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
var targets = target_list; | |||
return _call_tf_sessionrun(feeds, fetches); | |||
return _call_tf_sessionrun(feeds, fetches, target_list); | |||
} | |||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list) | |||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
{ | |||
// Ensure any changes to the graph are reflected in the runtime. | |||
_extend_graph(); | |||
@@ -95,8 +114,8 @@ namespace Tensorflow | |||
outputs: fetch_list, | |||
output_values: output_values, | |||
noutputs: fetch_list.Length, | |||
target_opers: IntPtr.Zero, | |||
ntargets: 0, | |||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
ntargets: target_list.Count, | |||
run_metadata: IntPtr.Zero, | |||
status: status); | |||
@@ -8,26 +8,37 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Fetch mapper for singleton tensors and ops. | |||
/// </summary> | |||
public class _ElementFetchMapper : _FetchMapper | |||
public class _ElementFetchMapper<T> : _FetchMapper<T> | |||
{ | |||
private List<Object> _unique_fetches = new List<object>(); | |||
private Action _contraction_fn; | |||
private List<object> _unique_fetches = new List<object>(); | |||
private Func<List<object>> _contraction_fn; | |||
public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn) | |||
public _ElementFetchMapper(List<T> fetches, Func<List<object>> contraction_fn) | |||
{ | |||
foreach(var tensor in fetches) | |||
foreach(var fetch in fetches) | |||
{ | |||
var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||
_unique_fetches.Add(fetch); | |||
var g = ops.get_default_graph(); | |||
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||
_unique_fetches.Add(el); | |||
} | |||
_contraction_fn = contraction_fn; | |||
} | |||
public NDArray build_results(NDArray[] values) | |||
/// <summary> | |||
/// Build results matching the original fetch shape. | |||
/// </summary> | |||
/// <param name="values"></param> | |||
/// <returns></returns> | |||
public NDArray build_results(List<object> values) | |||
{ | |||
return values[0]; | |||
if (values.Count == 0) | |||
return null; | |||
else | |||
return _contraction_fn(values); | |||
} | |||
public List<Object> unique_fetches() | |||
public List<object> unique_fetches() | |||
{ | |||
return _unique_fetches; | |||
} | |||
@@ -8,21 +8,26 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Handler for structured fetches. | |||
/// </summary> | |||
public class _FetchHandler | |||
public class _FetchHandler<T> | |||
{ | |||
private _ElementFetchMapper _fetch_mapper; | |||
private _ElementFetchMapper<T> _fetch_mapper; | |||
private List<Tensor> _fetches = new List<Tensor>(); | |||
private List<bool> _ops = new List<bool>(); | |||
private List<Tensor> _final_fetches = new List<Tensor>(); | |||
private List<object> _targets = new List<object>(); | |||
private List<T> _targets = new List<T>(); | |||
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||
{ | |||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | |||
foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
{ | |||
switch (fetch) | |||
{ | |||
case Operation val: | |||
_assert_fetchable(graph, val); | |||
_targets.Add((T)(object)val); | |||
_ops.Add(true); | |||
break; | |||
case Tensor val: | |||
_assert_fetchable(graph, val.op); | |||
_fetches.Add(val); | |||
@@ -35,9 +40,19 @@ namespace Tensorflow | |||
_final_fetches = _fetches; | |||
} | |||
public NDArray build_results(Session session, NDArray[] results) | |||
public NDArray build_results(Session session, NDArray[] tensor_values) | |||
{ | |||
return _fetch_mapper.build_results(results); | |||
var full_values = new List<object>(); | |||
foreach(var is_op in _ops) | |||
{ | |||
if (is_op) | |||
{ | |||
full_values.Add(null); | |||
} | |||
} | |||
return _fetch_mapper.build_results(full_values); | |||
} | |||
private void _assert_fetchable(Graph graph, Operation op) | |||
@@ -53,7 +68,7 @@ namespace Tensorflow | |||
return _final_fetches; | |||
} | |||
public List<Object> targets() | |||
public List<T> targets() | |||
{ | |||
return _targets; | |||
} | |||
@@ -4,13 +4,13 @@ using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class _FetchMapper | |||
public class _FetchMapper<T> | |||
{ | |||
public _ElementFetchMapper for_fetch(Tensor fetch) | |||
public _ElementFetchMapper<T> for_fetch(T fetch) | |||
{ | |||
var fetches = new List<Tensor> { fetch }; | |||
var fetches = new List<T> { fetch }; | |||
return new _ElementFetchMapper(fetches, null); | |||
return new _ElementFetchMapper<T>(fetches, null); | |||
} | |||
} | |||
} |
@@ -87,7 +87,7 @@ namespace Tensorflow | |||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||
TF_Output[] inputs, IntPtr[] input_values, int ninputs, | |||
TF_Output[] outputs, IntPtr[] output_values, int noutputs, | |||
IntPtr target_opers, int ntargets, | |||
IntPtr[] target_opers, int ntargets, | |||
IntPtr run_metadata, | |||
IntPtr status); | |||
} | |||
@@ -42,7 +42,7 @@ namespace Tensorflow | |||
/// <returns>An Op that run the initializers of all the specified variables.</returns> | |||
public static Operation variables_initializer(RefVariable[] var_list, string name = "init") | |||
{ | |||
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList()); | |||
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList(), name); | |||
} | |||
} | |||
} |
@@ -76,7 +76,7 @@ namespace TensorFlowNET.UnitTest | |||
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); | |||
var outputs_ptr = outputs_.ToArray(); | |||
var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); | |||
IntPtr targets_ptr = IntPtr.Zero; | |||
IntPtr[] targets_ptr = new IntPtr[0]; | |||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | |||
outputs_ptr, output_values_ptr, outputs_.Count, | |||
@@ -35,13 +35,13 @@ namespace TensorFlowNET.UnitTest | |||
using (var session = tf.Session()) | |||
{ | |||
/*session.run(model); | |||
session.run(model); | |||
for(int i = 0; i < 5; i++) | |||
{ | |||
x = x + 1; | |||
//x = x + 1; | |||
var result = session.run(x); | |||
print(result); | |||
}*/ | |||
} | |||
} | |||
} | |||