|
|
@@ -11,7 +11,8 @@ namespace Tensorflow.Eager |
|
|
|
public bool RecordGradient(string op_name, |
|
|
|
Tensor[] inputs, |
|
|
|
object[] attrs, |
|
|
|
Tensor[] results) |
|
|
|
Tensor[] results, |
|
|
|
Func<BackwardFunction> getBackwardFunction = null) |
|
|
|
{ |
|
|
|
var input_ids = MakeTensorIDList(inputs); |
|
|
|
var input_dtypes = MakeTensorDtypeList(inputs); |
|
|
@@ -77,13 +78,20 @@ namespace Tensorflow.Eager |
|
|
|
else |
|
|
|
op_inputs = inputs; |
|
|
|
|
|
|
|
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, |
|
|
|
() => GetGradientFunction(op_name, inputs, attrs, results)); |
|
|
|
|
|
|
|
TapeSetRecordOperation(op_name, inputs, results, |
|
|
|
getBackwardFunction ?? GetBackwradFunction(op_name, inputs, attrs, results)); |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
Func<BackwardFunction> GetBackwradFunction(string op_name, |
|
|
|
Tensor[] op_inputs, |
|
|
|
object[] attrs, |
|
|
|
Tensor[] op_outputs) |
|
|
|
{ |
|
|
|
return () => GetGradientFunction(op_name, op_inputs, attrs, op_outputs); |
|
|
|
} |
|
|
|
|
|
|
|
BackwardFunction GetGradientFunction(string op_name, |
|
|
|
Tensor[] op_inputs, |
|
|
|
object[] attrs, |
|
|
|