@@ -77,9 +77,9 @@ namespace Tensorflow | |||||
var temp_obj = _as_graph_element(obj); | var temp_obj = _as_graph_element(obj); | ||||
if(obj is Tensor && allow_tensor) | |||||
if (obj is Tensor tensor && allow_tensor) | |||||
{ | { | ||||
if ((obj as Tensor).Graph.Equals(this)) | |||||
if (tensor.Graph.Equals(this)) | |||||
{ | { | ||||
return obj; | return obj; | ||||
} | } | ||||
@@ -88,6 +88,17 @@ namespace Tensorflow | |||||
throw new Exception($"Tensor {obj} is not an element of this graph."); | throw new Exception($"Tensor {obj} is not an element of this graph."); | ||||
} | } | ||||
} | } | ||||
else if (obj is Operation op && allow_operation) | |||||
{ | |||||
if (op.Graph.Equals(this)) | |||||
{ | |||||
return obj; | |||||
} | |||||
else | |||||
{ | |||||
throw new Exception($"Operation {obj} is not an element of this graph."); | |||||
} | |||||
} | |||||
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | ||||
} | } | ||||
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -10,14 +11,35 @@ namespace Tensorflow | |||||
{ | { | ||||
using(var namescope = new ops.name_scope<Operation>(name, "group_deps", inputs)) | using(var namescope = new ops.name_scope<Operation>(name, "group_deps", inputs)) | ||||
{ | { | ||||
name = namescope; | |||||
var ops_on_device = new Dictionary<string, Operation[]>(); | |||||
// Sorts *inputs according to their devices. | // Sorts *inputs according to their devices. | ||||
foreach (var inp in inputs) | |||||
{ | |||||
ops_on_device[inp.Device] = new Operation[] { inp }; | |||||
} | |||||
// 1-level tree. The root node is the returned NoOp node. | |||||
if (ops_on_device.Count == 1) | |||||
{ | |||||
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); | |||||
} | |||||
return _GroupControlDeps("", name); | |||||
// 2-level tree. The root node is the returned NoOp node. | |||||
// deps contains 1 NoOp node for each device. | |||||
return null; | |||||
} | } | ||||
} | } | ||||
private static Operation _GroupControlDeps(string dev, string name = "") | |||||
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") | |||||
{ | { | ||||
if (string.IsNullOrEmpty(dev)) | |||||
{ | |||||
return gen_control_flow_ops.no_op(name); | |||||
} | |||||
return null; | return null; | ||||
} | } | ||||
} | } | ||||
@@ -0,0 +1,18 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class gen_control_flow_ops | |||||
{ | |||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||||
public static Operation no_op(string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("NoOp", name); | |||||
return _op; | |||||
} | |||||
} | |||||
} |
@@ -40,7 +40,12 @@ namespace Tensorflow | |||||
return _run(fetches, feed_dict); | return _run(fetches, feed_dict); | ||||
} | } | ||||
private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||||
public virtual NDArray run(Operation fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||||
{ | |||||
return _run(fetches, feed_dict); | |||||
} | |||||
private NDArray _run<T>(T fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||||
{ | { | ||||
var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | ||||
@@ -53,7 +58,7 @@ namespace Tensorflow | |||||
} | } | ||||
// Create a fetch handler to take care of the structure of fetches. | // Create a fetch handler to take care of the structure of fetches. | ||||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor); | |||||
// Run request and get response. | // Run request and get response. | ||||
// We need to keep the returned movers alive for the following _do_run(). | // We need to keep the returned movers alive for the following _do_run(). | ||||
@@ -65,20 +70,34 @@ namespace Tensorflow | |||||
// We only want to really perform the run if fetches or targets are provided, | // We only want to really perform the run if fetches or targets are provided, | ||||
// or if the call is a partial run that specifies feeds. | // or if the call is a partial run that specifies feeds. | ||||
var results = _do_run(final_fetches, feed_dict_tensor); | |||||
var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); | |||||
return fetch_handler.build_results(null, results); | return fetch_handler.build_results(null, results); | ||||
} | } | ||||
private NDArray[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||||
/// <summary> | |||||
/// Runs a step based on the given fetches and feeds. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||||
/// <param name="fetch_list"></param> | |||||
/// <param name="feed_dict"></param> | |||||
/// <returns> | |||||
/// A list of numpy ndarrays, corresponding to the elements of | |||||
/// `fetch_list`. If the ith element of `fetch_list` contains the | |||||
/// name of an operation, the first Tensor output of that operation | |||||
/// will be returned for that element. | |||||
/// </returns> | |||||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||||
{ | { | ||||
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | ||||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | ||||
var targets = target_list; | |||||
return _call_tf_sessionrun(feeds, fetches); | |||||
return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
} | } | ||||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list) | |||||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||||
{ | { | ||||
// Ensure any changes to the graph are reflected in the runtime. | // Ensure any changes to the graph are reflected in the runtime. | ||||
_extend_graph(); | _extend_graph(); | ||||
@@ -95,8 +114,8 @@ namespace Tensorflow | |||||
outputs: fetch_list, | outputs: fetch_list, | ||||
output_values: output_values, | output_values: output_values, | ||||
noutputs: fetch_list.Length, | noutputs: fetch_list.Length, | ||||
target_opers: IntPtr.Zero, | |||||
ntargets: 0, | |||||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
ntargets: target_list.Count, | |||||
run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
status: status); | status: status); | ||||
@@ -8,26 +8,37 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Fetch mapper for singleton tensors and ops. | /// Fetch mapper for singleton tensors and ops. | ||||
/// </summary> | /// </summary> | ||||
public class _ElementFetchMapper : _FetchMapper | |||||
public class _ElementFetchMapper<T> : _FetchMapper<T> | |||||
{ | { | ||||
private List<Object> _unique_fetches = new List<object>(); | |||||
private Action _contraction_fn; | |||||
private List<object> _unique_fetches = new List<object>(); | |||||
private Func<List<object>> _contraction_fn; | |||||
public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn) | |||||
public _ElementFetchMapper(List<T> fetches, Func<List<object>> contraction_fn) | |||||
{ | { | ||||
foreach(var tensor in fetches) | |||||
foreach(var fetch in fetches) | |||||
{ | { | ||||
var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||||
_unique_fetches.Add(fetch); | |||||
var g = ops.get_default_graph(); | |||||
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||||
_unique_fetches.Add(el); | |||||
} | } | ||||
_contraction_fn = contraction_fn; | |||||
} | } | ||||
public NDArray build_results(NDArray[] values) | |||||
/// <summary> | |||||
/// Build results matching the original fetch shape. | |||||
/// </summary> | |||||
/// <param name="values"></param> | |||||
/// <returns></returns> | |||||
public NDArray build_results(List<object> values) | |||||
{ | { | ||||
return values[0]; | |||||
if (values.Count == 0) | |||||
return null; | |||||
else | |||||
return _contraction_fn(values); | |||||
} | } | ||||
public List<Object> unique_fetches() | |||||
public List<object> unique_fetches() | |||||
{ | { | ||||
return _unique_fetches; | return _unique_fetches; | ||||
} | } | ||||
@@ -8,21 +8,26 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Handler for structured fetches. | /// Handler for structured fetches. | ||||
/// </summary> | /// </summary> | ||||
public class _FetchHandler | |||||
public class _FetchHandler<T> | |||||
{ | { | ||||
private _ElementFetchMapper _fetch_mapper; | |||||
private _ElementFetchMapper<T> _fetch_mapper; | |||||
private List<Tensor> _fetches = new List<Tensor>(); | private List<Tensor> _fetches = new List<Tensor>(); | ||||
private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
private List<object> _targets = new List<object>(); | |||||
private List<T> _targets = new List<T>(); | |||||
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||||
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||||
{ | { | ||||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||||
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | |||||
foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
{ | { | ||||
switch (fetch) | switch (fetch) | ||||
{ | { | ||||
case Operation val: | |||||
_assert_fetchable(graph, val); | |||||
_targets.Add((T)(object)val); | |||||
_ops.Add(true); | |||||
break; | |||||
case Tensor val: | case Tensor val: | ||||
_assert_fetchable(graph, val.op); | _assert_fetchable(graph, val.op); | ||||
_fetches.Add(val); | _fetches.Add(val); | ||||
@@ -35,9 +40,19 @@ namespace Tensorflow | |||||
_final_fetches = _fetches; | _final_fetches = _fetches; | ||||
} | } | ||||
public NDArray build_results(Session session, NDArray[] results) | |||||
public NDArray build_results(Session session, NDArray[] tensor_values) | |||||
{ | { | ||||
return _fetch_mapper.build_results(results); | |||||
var full_values = new List<object>(); | |||||
foreach(var is_op in _ops) | |||||
{ | |||||
if (is_op) | |||||
{ | |||||
full_values.Add(null); | |||||
} | |||||
} | |||||
return _fetch_mapper.build_results(full_values); | |||||
} | } | ||||
private void _assert_fetchable(Graph graph, Operation op) | private void _assert_fetchable(Graph graph, Operation op) | ||||
@@ -53,7 +68,7 @@ namespace Tensorflow | |||||
return _final_fetches; | return _final_fetches; | ||||
} | } | ||||
public List<Object> targets() | |||||
public List<T> targets() | |||||
{ | { | ||||
return _targets; | return _targets; | ||||
} | } | ||||
@@ -4,13 +4,13 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class _FetchMapper | |||||
public class _FetchMapper<T> | |||||
{ | { | ||||
public _ElementFetchMapper for_fetch(Tensor fetch) | |||||
public _ElementFetchMapper<T> for_fetch(T fetch) | |||||
{ | { | ||||
var fetches = new List<Tensor> { fetch }; | |||||
var fetches = new List<T> { fetch }; | |||||
return new _ElementFetchMapper(fetches, null); | |||||
return new _ElementFetchMapper<T>(fetches, null); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -87,7 +87,7 @@ namespace Tensorflow | |||||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | ||||
TF_Output[] inputs, IntPtr[] input_values, int ninputs, | TF_Output[] inputs, IntPtr[] input_values, int ninputs, | ||||
TF_Output[] outputs, IntPtr[] output_values, int noutputs, | TF_Output[] outputs, IntPtr[] output_values, int noutputs, | ||||
IntPtr target_opers, int ntargets, | |||||
IntPtr[] target_opers, int ntargets, | |||||
IntPtr run_metadata, | IntPtr run_metadata, | ||||
IntPtr status); | IntPtr status); | ||||
} | } | ||||
@@ -42,7 +42,7 @@ namespace Tensorflow | |||||
/// <returns>An Op that run the initializers of all the specified variables.</returns> | /// <returns>An Op that run the initializers of all the specified variables.</returns> | ||||
public static Operation variables_initializer(RefVariable[] var_list, string name = "init") | public static Operation variables_initializer(RefVariable[] var_list, string name = "init") | ||||
{ | { | ||||
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList()); | |||||
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList(), name); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -76,7 +76,7 @@ namespace TensorFlowNET.UnitTest | |||||
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); | var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); | ||||
var outputs_ptr = outputs_.ToArray(); | var outputs_ptr = outputs_.ToArray(); | ||||
var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); | var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); | ||||
IntPtr targets_ptr = IntPtr.Zero; | |||||
IntPtr[] targets_ptr = new IntPtr[0]; | |||||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | ||||
outputs_ptr, output_values_ptr, outputs_.Count, | outputs_ptr, output_values_ptr, outputs_.Count, | ||||
@@ -35,13 +35,13 @@ namespace TensorFlowNET.UnitTest | |||||
using (var session = tf.Session()) | using (var session = tf.Session()) | ||||
{ | { | ||||
/*session.run(model); | |||||
session.run(model); | |||||
for(int i = 0; i < 5; i++) | for(int i = 0; i < 5; i++) | ||||
{ | { | ||||
x = x + 1; | |||||
//x = x + 1; | |||||
var result = session.run(x); | var result = session.run(x); | ||||
print(result); | print(result); | ||||
}*/ | |||||
} | |||||
} | } | ||||
} | } | ||||