diff --git a/src/TensorFlowNET.Core/Operations/InputList.cs b/src/TensorFlowNET.Core/Operations/InputList.cs
index 2a802fd7..1d274dde 100644
--- a/src/TensorFlowNET.Core/Operations/InputList.cs
+++ b/src/TensorFlowNET.Core/Operations/InputList.cs
@@ -1,16 +1,23 @@
using System;
+using System.Collections;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow
{
- public class InputList
+ public class InputList : IEnumerable
{
public Tensor[] _inputs;
+ public Tensor this[int index] => _inputs[index];
public InputList(Tensor[] inputs)
{
_inputs = inputs;
}
+
+ public IEnumerator GetEnumerator()
+ {
+ return _inputs.GetEnumerator();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
new file mode 100644
index 00000000..5599ad2b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public partial class Operation
+ {
+ ///
+ /// Add this op to its control flow context.
+ ///
+ public void _control_flow_post_processing()
+ {
+ foreach(var input_tensor in inputs)
+ {
+
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
index 5fdc6d47..64c38c16 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
@@ -12,21 +12,7 @@ namespace Tensorflow
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);
private Tensor[] _outputs;
- public Tensor[] outputs
- {
- get
- {
- if (_outputs == null)
- {
- _outputs = new Tensor[NumOutputs];
-
- for (int i = 0; i < NumOutputs; i++)
- _outputs[i] = new Tensor(this, i, OutputType(i));
- }
-
- return _outputs;
- }
- }
+ public Tensor[] outputs => _outputs;
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index dfc89e67..284f9ee6 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -8,7 +8,7 @@ namespace Tensorflow
{
public partial class Operation
{
- private readonly IntPtr _handle;
+ private readonly IntPtr _handle; // _c_op in python
public Graph Graph { get; }
public int _id => _id_value;
@@ -97,12 +97,20 @@ namespace Tensorflow
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());
+ // Initialize self._outputs.
output_types = new TF_DataType[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
output_types[i] = OutputType(i);
+ _outputs = new Tensor[NumOutputs];
+ for (int i = 0; i < NumOutputs; i++)
+ _outputs[i] = new Tensor(this, i, OutputType(i));
+
Graph._add_op(this);
+
+ if (_handle != IntPtr.Zero)
+ _control_flow_post_processing();
}
public object get_attr(string name)
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 8767d5ea..b5e68ae2 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -18,7 +18,7 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("Add", keywords: keywords);
- return new Tensor(_op, 0, _op.OutputType(0));
+ return _op.outputs[0];
}
public static Tensor sub(Tensor x, Tensor y)
@@ -29,7 +29,7 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords);
- return new Tensor(_op, 0, _op.OutputType(0));
+ return _op.outputs[0];
}
public static Tensor mul(Tensor x, Tensor y)
@@ -40,7 +40,7 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("Mul", keywords: keywords);
- return new Tensor(_op, 0, _op.OutputType(0));
+ return _op.outputs[0];
}
public static Tensor real_div(Tensor x, Tensor y)
@@ -51,7 +51,7 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("RealDiv", name: "truediv", keywords: keywords);
- return new Tensor(_op, 0, _op.OutputType(0));
+ return _op.outputs[0];
}
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
@@ -64,7 +64,7 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("MatMul", keywords: keywords);
- return new Tensor(_op, 0, _op.OutputType(0));
+ return _op.outputs[0];
}
public static Tensor pow(Tensor x, double y)
@@ -75,7 +75,7 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("Pow", keywords: keywords);
- return new Tensor(_op, 0, _op.OutputType(0));
+ return _op.outputs[0];
}
public static Tensor sum(Tensor input, Tensor axis = null)
@@ -87,7 +87,7 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords);
- return new Tensor(_op, 0, _op.OutputType(0));
+ return _op.outputs[0];
}
///
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
index 5e7fe2a5..8345d0b7 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
@@ -23,6 +23,12 @@ namespace Tensorflow
public static implicit operator RefVariable(Tensor var)
{
+ switch (var.dtype)
+ {
+ case TF_DataType.TF_INT32:
+ return tf.Variable(var.Data()[0]);
+ }
+
return null;
}
}