|
|
@@ -41,17 +41,20 @@ namespace Tensorflow |
|
|
|
_graph_key = $"grap-key-{ops.uid()}/"; |
|
|
|
} |
|
|
|
|
|
|
|
public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) |
|
|
|
public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) |
|
|
|
{ |
|
|
|
return _as_graph_element_locked(obj, allow_tensor, allow_operation); |
|
|
|
} |
|
|
|
|
|
|
|
private Func<object> _as_graph_element(object obj) |
|
|
|
private Tensor _as_graph_element(object obj) |
|
|
|
{ |
|
|
|
if (obj is RefVariable var) |
|
|
|
return var._as_graph_element(); |
|
|
|
|
|
|
|
return null; |
|
|
|
} |
|
|
|
|
|
|
|
private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allow_operation = true) |
|
|
|
private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) |
|
|
|
{ |
|
|
|
string types_str = ""; |
|
|
|
|
|
|
@@ -69,12 +72,14 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
|
|
|
|
var temp_obj = _as_graph_element(obj); |
|
|
|
if (temp_obj != null) |
|
|
|
obj = temp_obj; |
|
|
|
|
|
|
|
if (obj is Tensor tensor && allow_tensor) |
|
|
|
{ |
|
|
|
if (tensor.Graph.Equals(this)) |
|
|
|
{ |
|
|
|
return obj; |
|
|
|
return tensor; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
@@ -85,7 +90,7 @@ namespace Tensorflow |
|
|
|
{ |
|
|
|
if (op.Graph.Equals(this)) |
|
|
|
{ |
|
|
|
return obj; |
|
|
|
return op; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
@@ -93,7 +98,7 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); |
|
|
|
throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); |
|
|
|
} |
|
|
|
|
|
|
|
public void add_to_collection<T>(string name, T value) |
|
|
|