diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index af92e905..c9e3be84 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -28,9 +28,9 @@ namespace Tensorflow
///
/// The data input ops for an op to be created.
/// A list of control inputs for the op to be created.
- private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
+ private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
{
- Operation[] ret = new Operation[0];
+ var ret = new ITensorOrOperation[0];
foreach(var controller in _control_dependencies_stack)
{
@@ -54,12 +54,12 @@ namespace Tensorflow
return ret;
}
- public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+ public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
{
if (control_inputs == null)
return new _ControlDependenciesController(this, null);
- var control_ops = new List();
+ var control_ops = new List();
foreach (var c in control_inputs)
{
control_ops.Add(c);
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index aa3eb26e..c9d93182 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -298,6 +298,11 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray();
}
+ public string[] get_all_collection_keys()
+ {
+ return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
+ }
+
public object get_collection(string name, string scope = "")
{
return _collections.ContainsKey(name) ? _collections[name] : null;
diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
index 08302cc1..f1ddcb44 100644
--- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
+++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
@@ -11,20 +11,20 @@ namespace Tensorflow
public class _ControlDependenciesController : IPython
{
private Graph _graph;
- private List _control_inputs_val;
- private List _seen_nodes;
+ private List _control_inputs_val;
+ private List _seen_nodes;
private Queue<_ControlDependenciesController> _old_stack;
private bool _new_stack;
private Context _old_control_flow_context;
- public Operation[] control_inputs => _control_inputs_val.ToArray();
+ public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();
- public _ControlDependenciesController(Graph graph, List control_inputs)
+ public _ControlDependenciesController(Graph graph, List control_inputs)
{
_graph = graph;
if (control_inputs == null)
{
- _control_inputs_val = new List();
+ _control_inputs_val = new List();
_new_stack = true;
}
else
@@ -33,15 +33,15 @@ namespace Tensorflow
_new_stack = false;
}
- _seen_nodes = new List();
+ _seen_nodes = new List();
}
- public void add_op(Operation op)
+ public void add_op(ITensorOrOperation op)
{
_seen_nodes.Add(op);
}
- public bool op_in_group(Operation op)
+ public bool op_in_group(ITensorOrOperation op)
{
return _seen_nodes.Contains(op);
}
diff --git a/src/TensorFlowNET.Core/ITensorOrOperation.cs b/src/TensorFlowNET.Core/ITensorOrOperation.cs
index c29713b9..f12a0b02 100644
--- a/src/TensorFlowNET.Core/ITensorOrOperation.cs
+++ b/src/TensorFlowNET.Core/ITensorOrOperation.cs
@@ -11,5 +11,6 @@ namespace Tensorflow
public interface ITensorOrOperation
{
string Device { get; }
+ Operation op { get; }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
index 682f59ec..76dd318a 100644
--- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
+++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
@@ -107,7 +107,9 @@ namespace Tensorflow
values = ops.internal_convert_to_tensor(values,
name: input_name,
- as_ref: input_arg.IsRef);
+ dtype: dtype,
+ as_ref: input_arg.IsRef,
+ preferred_dtype: default_dtype);
//if (!String.IsNullOrEmpty(input_arg.TypeAttr))
//attrs[input_arg.TypeAttr] = values.dtype;
@@ -163,14 +165,20 @@ namespace Tensorflow
foreach (var arg in op_def.OutputArg)
{
+ types = new List();
if (!string.IsNullOrEmpty(arg.NumberAttr))
{
}
else if (!string.IsNullOrEmpty(arg.TypeAttr))
{
- output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
+ types = new List() { (TF_DataType)attr_protos[arg.TypeAttr].Type };
}
+
+ if (arg.IsRef)
+ types = types.Select(x => x.as_ref()).ToList();
+
+ output_types.AddRange(types);
}
// Add Op to graph
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index e6450393..0549edf0 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -16,6 +16,7 @@ namespace Tensorflow
private int _id_value;
public string type => OpType;
+ public Operation op => this;
private Status status = new Status();
@@ -75,7 +76,7 @@ namespace Tensorflow
///
///
///
- public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
+ public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
Graph = g;
@@ -120,6 +121,11 @@ namespace Tensorflow
_control_flow_post_processing();
}
+ public void run(FeedItem[] feed_dict = null, Session session = null)
+ {
+ ops._run_using_default_session(this, feed_dict, Graph, session);
+ }
+
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs)
{
var grouped_inputs = new List