@@ -22,6 +22,7 @@ namespace Tensorflow | |||||
private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
public int _version; | public int _version; | ||||
private int _next_id_counter; | private int _next_id_counter; | ||||
private List<String> _unfetchable_ops = new List<string>(); | |||||
public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
{ | { | ||||
@@ -111,6 +112,20 @@ namespace Tensorflow | |||||
return ++_next_id_counter; | return ++_next_id_counter; | ||||
} | } | ||||
public bool is_fetchable<T>(T tensor_or_op) | |||||
{ | |||||
if (tensor_or_op is Tensor) | |||||
{ | |||||
return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; | |||||
} | |||||
else if (tensor_or_op is Operation) | |||||
{ | |||||
return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); | |||||
} | |||||
return false; | |||||
} | |||||
public string unique_name(string name) | public string unique_name(string name) | ||||
{ | { | ||||
var name_key = name.ToLower(); | var name_key = name.ToLower(); | ||||
@@ -20,5 +20,10 @@ namespace Tensorflow | |||||
_unique_fetches.Add(fetch); | _unique_fetches.Add(fetch); | ||||
} | } | ||||
} | } | ||||
public List<Object> unique_fetches() | |||||
{ | |||||
return _unique_fetches; | |||||
} | |||||
} | } | ||||
} | } |
@@ -10,10 +10,35 @@ namespace Tensorflow | |||||
public class _FetchHandler | public class _FetchHandler | ||||
{ | { | ||||
private _ElementFetchMapper _fetch_mapper; | private _ElementFetchMapper _fetch_mapper; | ||||
private List<object> _fetches = new List<object>(); | |||||
private List<bool> _ops = new List<bool>(); | |||||
private List<object> _final_fetches = new List<object>(); | |||||
public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | ||||
{ | { | ||||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | _fetch_mapper = new _FetchMapper().for_fetch(fetches); | ||||
foreach(var fetch in _fetch_mapper.unique_fetches()) | |||||
{ | |||||
switch (fetch) | |||||
{ | |||||
case Tensor val: | |||||
_assert_fetchable(graph, val.op); | |||||
_fetches.Add(fetch); | |||||
_ops.Add(false); | |||||
break; | |||||
} | |||||
} | |||||
_final_fetches = _fetches; | |||||
} | |||||
private void _assert_fetchable(Graph graph, Operation op) | |||||
{ | |||||
if (!graph.is_fetchable(op)) | |||||
{ | |||||
throw new Exception($"Operation {op.name} has been marked as not fetchable."); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |