diff --git a/src/TensorFlowNET.Core/Eager/Execute.cs b/src/TensorFlowNET.Core/Eager/Execute.cs index ed8c7839..60aeb4f9 100644 --- a/src/TensorFlowNET.Core/Eager/Execute.cs +++ b/src/TensorFlowNET.Core/Eager/Execute.cs @@ -6,12 +6,9 @@ namespace Tensorflow.Eager { public class Execute { - public void record_gradient(string op_name, Tensor[] inputs, Dictionary attrs, Tensor[] results, string name = "") + public void record_gradient(string op_name, InputList inputs, Dictionary attrs, Tensor[] results, string name = "") { - if (inputs == null) - inputs = new Tensor[0]; - - pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name); + pywrap_tfe_src.RecordGradient(op_name, inputs._inputs, attrs, results, name); } } } diff --git a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs index 79dc67a8..b0ebd2f8 100644 --- a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs +++ b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs @@ -25,6 +25,9 @@ namespace Tensorflow.Eager } } if (!should_record) return; + + var op_outputs = results; + var op_inputs = inputs; } } } diff --git a/src/TensorFlowNET.Core/Operations/InputList.cs b/src/TensorFlowNET.Core/Operations/InputList.cs new file mode 100644 index 00000000..2a802fd7 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/InputList.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class InputList + { + public Tensor[] _inputs; + + public InputList(Tensor[] inputs) + { + _inputs = inputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 00ffb76f..f1306fc0 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -101,21 +101,23 @@ namespace Tensorflow } } - private Tensor[] _inputs; - public Tensor[] inputs + private InputList _inputs; + public InputList inputs { get { if(_inputs == null) { - _inputs = new Tensor[NumInputs]; + var retval = new Tensor[NumInputs]; for (int i = 0; i < NumInputs; i++) { var tf_outpus = Input(i); var op = new Operation(tf_outpus.oper); - _inputs[i] = op.outputs[tf_outpus.index]; + retval[i] = op.outputs[tf_outpus.index]; } + + _inputs = new InputList(retval); } return _inputs;