diff --git a/src/TensorFlowNET.Core/Operations/InputList.cs b/src/TensorFlowNET.Core/Operations/InputList.cs index 2a802fd7..1d274dde 100644 --- a/src/TensorFlowNET.Core/Operations/InputList.cs +++ b/src/TensorFlowNET.Core/Operations/InputList.cs @@ -1,16 +1,23 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow { - public class InputList + public class InputList : IEnumerable { public Tensor[] _inputs; + public Tensor this[int index] => _inputs[index]; public InputList(Tensor[] inputs) { _inputs = inputs; } + + public IEnumerator GetEnumerator() + { + return _inputs.GetEnumerator(); + } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs new file mode 100644 index 00000000..5599ad2b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class Operation + { + /// + /// Add this op to its control flow context. + /// + public void _control_flow_post_processing() + { + foreach(var input_tensor in inputs) + { + + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 5fdc6d47..64c38c16 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -12,21 +12,7 @@ namespace Tensorflow public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); private Tensor[] _outputs; - public Tensor[] outputs - { - get - { - if (_outputs == null) - { - _outputs = new Tensor[NumOutputs]; - - for (int i = 0; i < NumOutputs; i++) - _outputs[i] = new Tensor(this, i, OutputType(i)); - } - - return _outputs; - } - } + public Tensor[] outputs => _outputs; public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index dfc89e67..284f9ee6 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -8,7 +8,7 @@ namespace Tensorflow { public partial class Operation { - private readonly IntPtr _handle; + private readonly IntPtr _handle; // _c_op in python public Graph Graph { get; } public int _id => _id_value; @@ -97,12 +97,20 @@ namespace Tensorflow _handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); + // Initialize self._outputs. output_types = new TF_DataType[NumOutputs]; for (int i = 0; i < NumOutputs; i++) output_types[i] = OutputType(i); + _outputs = new Tensor[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, OutputType(i)); + Graph._add_op(this); + + if (_handle != IntPtr.Zero) + _control_flow_post_processing(); } public object get_attr(string name) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 8767d5ea..b5e68ae2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -18,7 +18,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Add", keywords: keywords); - return new Tensor(_op, 0, _op.OutputType(0)); + return _op.outputs[0]; } public static Tensor sub(Tensor x, Tensor y) @@ -29,7 +29,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords); - return new Tensor(_op, 0, _op.OutputType(0)); + return _op.outputs[0]; } public static Tensor mul(Tensor x, Tensor y) @@ -40,7 +40,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Mul", keywords: keywords); - return new Tensor(_op, 0, _op.OutputType(0)); + return _op.outputs[0]; } public static Tensor real_div(Tensor x, Tensor y) @@ -51,7 +51,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("RealDiv", name: "truediv", keywords: keywords); - return new Tensor(_op, 0, _op.OutputType(0)); + return _op.outputs[0]; } public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) @@ -64,7 +64,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("MatMul", keywords: keywords); - return new Tensor(_op, 0, _op.OutputType(0)); + return _op.outputs[0]; } public static Tensor pow(Tensor x, double y) @@ -75,7 +75,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Pow", keywords: keywords); - return new Tensor(_op, 0, _op.OutputType(0)); + return _op.outputs[0]; } public static Tensor sum(Tensor input, Tensor axis = null) @@ -87,7 +87,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords); - return new Tensor(_op, 0, _op.OutputType(0)); + return _op.outputs[0]; } /// diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs index 5e7fe2a5..8345d0b7 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs @@ -23,6 +23,12 @@ namespace Tensorflow public static implicit operator RefVariable(Tensor var) { + switch (var.dtype) + { + case TF_DataType.TF_INT32: + return tf.Variable(var.Data()[0]); + } + return null; } }