@@ -6,12 +6,9 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
public class Execute | public class Execute | ||||
{ | { | ||||
public void record_gradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | |||||
public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | |||||
{ | { | ||||
if (inputs == null) | |||||
inputs = new Tensor[0]; | |||||
pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name); | |||||
pywrap_tfe_src.RecordGradient(op_name, inputs._inputs, attrs, results, name); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -25,6 +25,9 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
} | } | ||||
if (!should_record) return; | if (!should_record) return; | ||||
var op_outputs = results; | |||||
var op_inputs = inputs; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class InputList | |||||
{ | |||||
public Tensor[] _inputs; | |||||
public InputList(Tensor[] inputs) | |||||
{ | |||||
_inputs = inputs; | |||||
} | |||||
} | |||||
} |
@@ -101,21 +101,23 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
private Tensor[] _inputs; | |||||
public Tensor[] inputs | |||||
private InputList _inputs; | |||||
public InputList inputs | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
if(_inputs == null) | if(_inputs == null) | ||||
{ | { | ||||
_inputs = new Tensor[NumInputs]; | |||||
var retval = new Tensor[NumInputs]; | |||||
for (int i = 0; i < NumInputs; i++) | for (int i = 0; i < NumInputs; i++) | ||||
{ | { | ||||
var tf_outpus = Input(i); | var tf_outpus = Input(i); | ||||
var op = new Operation(tf_outpus.oper); | var op = new Operation(tf_outpus.oper); | ||||
_inputs[i] = op.outputs[tf_outpus.index]; | |||||
retval[i] = op.outputs[tf_outpus.index]; | |||||
} | } | ||||
_inputs = new InputList(retval); | |||||
} | } | ||||
return _inputs; | return _inputs; | ||||