Browse Source

_FetchHandler

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
2300e619aa
3 changed files with 45 additions and 0 deletions
  1. +15
    -0
      src/TensorFlowNET.Core/Graph.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs
  3. +25
    -0
      src/TensorFlowNET.Core/Session/_FetchHandler.cs

+ 15
- 0
src/TensorFlowNET.Core/Graph.cs View File

@@ -22,6 +22,7 @@ namespace Tensorflow
private Dictionary<string, int> _names_in_use;
public int _version;
private int _next_id_counter;
private List<String> _unfetchable_ops = new List<string>();

public Graph(IntPtr graph)
{
@@ -111,6 +112,20 @@ namespace Tensorflow
return ++_next_id_counter;
}

public bool is_fetchable<T>(T tensor_or_op)
{
if (tensor_or_op is Tensor)
{
return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ;
}
else if (tensor_or_op is Operation)
{
return !_unfetchable_ops.Contains((tensor_or_op as Operation).name);
}

return false;
}

public string unique_name(string name)
{
var name_key = name.ToLower();


+ 5
- 0
src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs View File

@@ -20,5 +20,10 @@ namespace Tensorflow
_unique_fetches.Add(fetch);
}
}

public List<Object> unique_fetches()
{
return _unique_fetches;
}
}
}

+ 25
- 0
src/TensorFlowNET.Core/Session/_FetchHandler.cs View File

@@ -10,10 +10,35 @@ namespace Tensorflow
public class _FetchHandler
{
private _ElementFetchMapper _fetch_mapper;
private List<object> _fetches = new List<object>();
private List<bool> _ops = new List<bool>();
private List<object> _final_fetches = new List<object>();

public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null)
{
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())
{
switch (fetch)
{
case Tensor val:
_assert_fetchable(graph, val.op);
_fetches.Add(fetch);
_ops.Add(false);
break;
}

}

_final_fetches = _fetches;
}

private void _assert_fetchable(Graph graph, Operation op)
{
if (!graph.is_fetchable(op))
{
throw new Exception($"Operation {op.name} has been marked as not fetchable.");
}
}
}
}

Loading…
Cancel
Save