|
|
@@ -16,7 +16,7 @@ namespace Tensorflow |
|
|
|
private List<Tensor> _final_fetches = new List<Tensor>(); |
|
|
|
private List<T> _targets = new List<T>(); |
|
|
|
|
|
|
|
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) |
|
|
|
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null) |
|
|
|
{ |
|
|
|
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); |
|
|
|
foreach(var fetch in _fetch_mapper.unique_fetches()) |
|
|
@@ -40,18 +40,32 @@ namespace Tensorflow |
|
|
|
_final_fetches = _fetches; |
|
|
|
} |
|
|
|
|
|
|
|
public NDArray build_results(Session session, NDArray[] tensor_values) |
|
|
|
public NDArray build_results(BaseSession session, NDArray[] tensor_values) |
|
|
|
{ |
|
|
|
var full_values = new List<object>(); |
|
|
|
if (_final_fetches.Count != tensor_values.Length) |
|
|
|
throw new InvalidOperationException("_final_fetches mismatch tensor_values"); |
|
|
|
|
|
|
|
int i = 0; |
|
|
|
int j = 0; |
|
|
|
foreach(var is_op in _ops) |
|
|
|
{ |
|
|
|
if (is_op) |
|
|
|
{ |
|
|
|
full_values.Add(null); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
var value = tensor_values[j]; |
|
|
|
j += 1; |
|
|
|
full_values.Add(value); |
|
|
|
} |
|
|
|
i += 1; |
|
|
|
} |
|
|
|
|
|
|
|
if (j != tensor_values.Length) |
|
|
|
throw new InvalidOperationException("j mismatch tensor_values"); |
|
|
|
|
|
|
|
return _fetch_mapper.build_results(full_values); |
|
|
|
} |
|
|
|
|
|
|
|