@@ -48,6 +48,14 @@ namespace Tensorflow.Gradients | |||||
_recording = true; | _recording = true; | ||||
} | } | ||||
private void _pop_tape() | |||||
{ | |||||
if (!_recording) | |||||
throw new ValueError("Tape is not recording."); | |||||
_tape.pop_tape(_tape); | |||||
_recording = false; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Marks this tensor to be watched by the given tape. | /// Marks this tensor to be watched by the given tape. | ||||
/// </summary> | /// </summary> | ||||
@@ -59,12 +67,19 @@ namespace Tensorflow.Gradients | |||||
public Tensor gradient(Tensor target, Tensor sources) | public Tensor gradient(Tensor target, Tensor sources) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
if(_recording) | |||||
{ | { | ||||
c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status); | |||||
if (!_persistent) | |||||
_pop_tape(); | |||||
} | } | ||||
return null; | |||||
using var status = new Status(); | |||||
var et = c_api.TFE_TapeGradient(_tape, | |||||
new IntPtr[] { (target as EagerTensor).EagerTensorHandle }, 1, | |||||
new IntPtr[] { (sources as EagerTensor).EagerTensorHandle }, 1, | |||||
status); | |||||
status.Check(true); | |||||
return et; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -20,6 +20,11 @@ namespace Tensorflow.Gradients | |||||
c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle); | c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle); | ||||
} | } | ||||
public void pop_tape(Tape tape) | |||||
{ | |||||
c_api.TFE_TapeSetRemove(tape); | |||||
} | |||||
public static bool IsDtypeTrainable(DataType dtype) | public static bool IsDtypeTrainable(DataType dtype) | ||||
{ | { | ||||
switch (dtype) | switch (dtype) | ||||
@@ -715,17 +715,15 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
using var status = new Status(); | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Mul", name, new IntPtr[] | |||||
{ | { | ||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Mul", name, new IntPtr[] | |||||
{ | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | ||||