Browse Source

_ElementFetchMapper.build_results

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
ff51917018
12 changed files with 136 additions and 40 deletions
  1. +13
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +24
    -2
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  3. +18
    -0
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  4. +27
    -8
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  5. +21
    -10
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  6. +23
    -8
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  7. +4
    -4
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Sessions/c_api.session.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Variables/variables.py.cs
  10. BIN
      src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll
  11. +1
    -1
      test/TensorFlowNET.UnitTest/CSession.cs
  12. +3
    -3
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 13
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -77,9 +77,9 @@ namespace Tensorflow

var temp_obj = _as_graph_element(obj);

if(obj is Tensor && allow_tensor)
if (obj is Tensor tensor && allow_tensor)
{
if ((obj as Tensor).Graph.Equals(this))
if (tensor.Graph.Equals(this))
{
return obj;
}
@@ -88,6 +88,17 @@ namespace Tensorflow
throw new Exception($"Tensor {obj} is not an element of this graph.");
}
}
else if (obj is Operation op && allow_operation)
{
if (op.Graph.Equals(this))
{
return obj;
}
else
{
throw new Exception($"Operation {obj} is not an element of this graph.");
}
}

throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
}


+ 24
- 2
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
@@ -10,14 +11,35 @@ namespace Tensorflow
{
using(var namescope = new ops.name_scope<Operation>(name, "group_deps", inputs))
{
name = namescope;

var ops_on_device = new Dictionary<string, Operation[]>();

// Sorts *inputs according to their devices.
foreach (var inp in inputs)
{
ops_on_device[inp.Device] = new Operation[] { inp };
}

// 1-level tree. The root node is the returned NoOp node.
if (ops_on_device.Count == 1)
{
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name);
}

return _GroupControlDeps("", name);
// 2-level tree. The root node is the returned NoOp node.
// deps contains 1 NoOp node for each device.
return null;
}
}

private static Operation _GroupControlDeps(string dev, string name = "")
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "")
{
if (string.IsNullOrEmpty(dev))
{
return gen_control_flow_ops.no_op(name);
}

return null;
}
}


+ 18
- 0
src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class gen_control_flow_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

public static Operation no_op(string name = "")
{
var _op = _op_def_lib._apply_op_helper("NoOp", name);

return _op;
}
}
}

+ 27
- 8
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -40,7 +40,12 @@ namespace Tensorflow
return _run(fetches, feed_dict);
}

private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
public virtual NDArray run(Operation fetches, Dictionary<Tensor, NDArray> feed_dict = null)
{
return _run(fetches, feed_dict);
}

private NDArray _run<T>(T fetches, Dictionary<Tensor, NDArray> feed_dict = null)
{
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();

@@ -53,7 +58,7 @@ namespace Tensorflow
}

// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor);

// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
@@ -65,20 +70,34 @@ namespace Tensorflow

// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_fetches, feed_dict_tensor);
var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);

return fetch_handler.build_results(null, results);
}

private NDArray[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list"></param>
/// <param name="feed_dict"></param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
{
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list;

return _call_tf_sessionrun(feeds, fetches);
return _call_tf_sessionrun(feeds, fetches, target_list);
}

private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list)
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();
@@ -95,8 +114,8 @@ namespace Tensorflow
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: IntPtr.Zero,
ntargets: 0,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: status);



+ 21
- 10
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -8,26 +8,37 @@ namespace Tensorflow
/// <summary>
/// Fetch mapper for singleton tensors and ops.
/// </summary>
public class _ElementFetchMapper : _FetchMapper
public class _ElementFetchMapper<T> : _FetchMapper<T>
{
private List<Object> _unique_fetches = new List<object>();
private Action _contraction_fn;
private List<object> _unique_fetches = new List<object>();
private Func<List<object>> _contraction_fn;

public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn)
public _ElementFetchMapper(List<T> fetches, Func<List<object>> contraction_fn)
{
foreach(var tensor in fetches)
foreach(var fetch in fetches)
{
var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true);
_unique_fetches.Add(fetch);
var g = ops.get_default_graph();
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
_unique_fetches.Add(el);
}

_contraction_fn = contraction_fn;
}

public NDArray build_results(NDArray[] values)
/// <summary>
/// Build results matching the original fetch shape.
/// </summary>
/// <param name="values"></param>
/// <returns></returns>
public NDArray build_results(List<object> values)
{
return values[0];
if (values.Count == 0)
return null;
else
return _contraction_fn(values);
}

public List<Object> unique_fetches()
public List<object> unique_fetches()
{
return _unique_fetches;
}


+ 23
- 8
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -8,21 +8,26 @@ namespace Tensorflow
/// <summary>
/// Handler for structured fetches.
/// </summary>
public class _FetchHandler
public class _FetchHandler<T>
{
private _ElementFetchMapper _fetch_mapper;
private _ElementFetchMapper<T> _fetch_mapper;
private List<Tensor> _fetches = new List<Tensor>();
private List<bool> _ops = new List<bool>();
private List<Tensor> _final_fetches = new List<Tensor>();
private List<object> _targets = new List<object>();
private List<T> _targets = new List<T>();

public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
{
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())
{
switch (fetch)
{
case Operation val:
_assert_fetchable(graph, val);
_targets.Add((T)(object)val);
_ops.Add(true);
break;
case Tensor val:
_assert_fetchable(graph, val.op);
_fetches.Add(val);
@@ -35,9 +40,19 @@ namespace Tensorflow
_final_fetches = _fetches;
}

public NDArray build_results(Session session, NDArray[] results)
public NDArray build_results(Session session, NDArray[] tensor_values)
{
return _fetch_mapper.build_results(results);
var full_values = new List<object>();

foreach(var is_op in _ops)
{
if (is_op)
{
full_values.Add(null);
}
}

return _fetch_mapper.build_results(full_values);
}

private void _assert_fetchable(Graph graph, Operation op)
@@ -53,7 +68,7 @@ namespace Tensorflow
return _final_fetches;
}

public List<Object> targets()
public List<T> targets()
{
return _targets;
}


+ 4
- 4
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -4,13 +4,13 @@ using System.Text;

namespace Tensorflow
{
public class _FetchMapper
public class _FetchMapper<T>
{
public _ElementFetchMapper for_fetch(Tensor fetch)
public _ElementFetchMapper<T> for_fetch(T fetch)
{
var fetches = new List<Tensor> { fetch };
var fetches = new List<T> { fetch };

return new _ElementFetchMapper(fetches, null);
return new _ElementFetchMapper<T>(fetches, null);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -87,7 +87,7 @@ namespace Tensorflow
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
TF_Output[] inputs, IntPtr[] input_values, int ninputs,
TF_Output[] outputs, IntPtr[] output_values, int noutputs,
IntPtr target_opers, int ntargets,
IntPtr[] target_opers, int ntargets,
IntPtr run_metadata,
IntPtr status);
}


+ 1
- 1
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow
/// <returns>An Op that run the initializers of all the specified variables.</returns>
public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
{
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList());
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList(), name);
}
}
}

BIN
src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll View File


+ 1
- 1
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -76,7 +76,7 @@ namespace TensorFlowNET.UnitTest
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
var outputs_ptr = outputs_.ToArray();
var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
IntPtr targets_ptr = IntPtr.Zero;
IntPtr[] targets_ptr = new IntPtr[0];

c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
outputs_ptr, output_values_ptr, outputs_.Count,


+ 3
- 3
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -35,13 +35,13 @@ namespace TensorFlowNET.UnitTest

using (var session = tf.Session())
{
/*session.run(model);
session.run(model);
for(int i = 0; i < 5; i++)
{
x = x + 1;
//x = x + 1;
var result = session.run(x);
print(result);
}*/
}
}

}


Loading…
Cancel
Save