@@ -1,6 +1,7 @@
+using NumSharp.Core;
 using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 
 namespace Tensorflow
@@ -38,12 +39,14 @@ namespace Tensorflow
 
         }
 
-        public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
+        public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
         {
-            return _run(fetches, feed_dict);
+            var result = _run(fetches, feed_dict);
+
+            return result;
         }
 
-        private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
+        private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
         {
             var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
 
@@ -66,22 +69,71 @@ namespace Tensorflow
             // Create a fetch handler to take care of the structure of fetches.
             var fetch_handler = new _FetchHandler(_graph, fetches);
 
+            // Run request and get response.
+            // We need to keep the returned movers alive for the following _do_run().
+            // These movers are no longer needed when _do_run() completes, and
+            // are deleted when `movers` goes out of scope when this _run() ends.
+            var _ = _update_with_movers();
+            var final_fetches = fetch_handler.fetches();
+            var final_targets = fetch_handler.targets();
+
+            // We only want to really perform the run if fetches or targets are provided,
+            // or if the call is a partial run that specifies feeds.
+            var results = _do_run(final_fetches);
+
+            return fetch_handler.build_results(null, results);
+        }
+
+        private object[] _do_run(List<object> fetch_list)
+        {
+            var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray();
+
+            return _call_tf_sessionrun(fetches);
+        }
+
+        private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list)
+        {
+            // Ensure any changes to the graph are reflected in the runtime.
+            _extend_graph();
+
+            var status = new Status();
+
+            var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();
 
             c_api.TF_SessionRun(_session,
                 run_options: IntPtr.Zero,
                 inputs: new TF_Output[] { },
                 input_values: new IntPtr[] { },
                 ninputs: 0,
-                outputs: new TF_Output[] { new TF_Output() },
-                output_values: new IntPtr[] { },
-                noutputs: 1,
+                outputs: fetch_list,
+                output_values: output_values,
+                noutputs: fetch_list.Length,
                 target_opers: new IntPtr[] { },
-                ntargets: 1,
+                ntargets: 0,
                 run_metadata: IntPtr.Zero,
                 status: status.Handle);
 
-            return null;
+            var result = output_values.Select(x => new Tensor(x).buffer).Select(x =>
+            {
+                return (object)*(float*)x;
+            }).ToArray();
+
+            return result;
         }
 
+        /// <summary>
+        /// If a tensor handle is fed to a device-incompatible placeholder,
+        /// we move the tensor to the right device, generate a new tensor handle,
+        /// and update feed_dict to use the new handle.
+        /// </summary>
+        private List<object> _update_with_movers()
+        {
+            return new List<object> { };
+        }
+
+        private void _extend_graph()
+        {
+
+        }
     }
 }
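
A note on the marshalling step in `_call_tf_sessionrun` above: the method hands `TF_SessionRun` an array of `IntPtr` slots (`output_values`) and afterwards reads each returned buffer back as a single boxed `float` via an unsafe dereference (`*(float*)x`). The standalone sketch below reproduces only that pointer pattern, with `Marshal.AllocHGlobal` standing in for the tensor buffer the runtime would fill in; it is an illustration under that assumption, not code from this change, and it ignores outputs with other dtypes or more than one element.

using System;
using System.Runtime.InteropServices;

class OutputBufferSketch
{
    static unsafe void Main()
    {
        // Stand-in for the pointer TF_SessionRun would write into output_values[i].
        IntPtr buffer = Marshal.AllocHGlobal(sizeof(float));
        try
        {
            *(float*)buffer = 3.14f;   // pretend the runtime stored the fetched scalar here

            // Same read-back pattern as _call_tf_sessionrun: treat the buffer as one float
            // and box it so it can flow through run()'s new object return type.
            object boxed = *(float*)buffer;
            Console.WriteLine(boxed);  // prints the value read back from the buffer
        }
        finally
        {
            Marshal.FreeHGlobal(buffer);   // the real code would release the TF tensor instead
        }
    }
}

The sketch needs unsafe blocks enabled at compile time, matching the `unsafe` modifier that `_run` and `_call_tf_sessionrun` carry in the diff.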