Browse Source

override graph

tags/v0.12
Oceania2018 6 years ago
parent
commit
67949251b2
3 changed files with 5 additions and 5 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs

+ 2
- 2
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -28,9 +28,9 @@ namespace Tensorflow
{
private Func<List<NDArray>, object> _contraction_fn;

public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn, Graph graph = null)
{
var g = ops.get_default_graph();
var g = graph ?? ops.get_default_graph();

foreach(var fetch in fetches)
{


+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -34,7 +34,7 @@ namespace Tensorflow

public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
{
_fetch_mapper = _FetchMapper.for_fetch(fetches);
_fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph);
foreach(var fetch in _fetch_mapper.unique_fetches())
{
switch (fetch)


+ 2
- 2
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow
{
protected List<ITensorOrOperation> _unique_fetches = new List<ITensorOrOperation>();
protected List<int[]> _value_indices = new List<int[]>();
public static _FetchMapper for_fetch(object fetch)
public static _FetchMapper for_fetch(object fetch, Graph graph = null)
{
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };

@@ -34,7 +34,7 @@ namespace Tensorflow
if (fetch.GetType().IsArray)
return new _ListFetchMapper(fetches);
else
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]);
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0], graph: graph);
}

public virtual NDArray[] build_results(List<NDArray> values)


Loading…
Cancel
Save