Browse Source

add _pop_tape.

tags/v0.20
Oceania2018 5 years ago
parent
commit
22d56abac8
3 changed files with 32 additions and 14 deletions
  1. +19
    -4
      src/TensorFlowNET.Core/Gradients/GradientActor.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Gradients/Tape.cs
  3. +8
    -10
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs

+ 19
- 4
src/TensorFlowNET.Core/Gradients/GradientActor.cs View File

@@ -48,6 +48,14 @@ namespace Tensorflow.Gradients
_recording = true;
}

private void _pop_tape()
{
if (!_recording)
throw new ValueError("Tape is not recording.");
_tape.pop_tape(_tape);
_recording = false;
}

/// <summary>
/// Marks this tensor to be watched by the given tape.
/// </summary>
@@ -59,12 +67,19 @@ namespace Tensorflow.Gradients

public Tensor gradient(Tensor target, Tensor sources)
{
using (var status = new Status())
if(_recording)
{
c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status);
if (!_persistent)
_pop_tape();
}
return null;

using var status = new Status();
var et = c_api.TFE_TapeGradient(_tape,
new IntPtr[] { (target as EagerTensor).EagerTensorHandle }, 1,
new IntPtr[] { (sources as EagerTensor).EagerTensorHandle }, 1,
status);
status.Check(true);
return et;
}

public void Dispose()


+ 5
- 0
src/TensorFlowNET.Core/Gradients/Tape.cs View File

@@ -20,6 +20,11 @@ namespace Tensorflow.Gradients
c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle);
}

public void pop_tape(Tape tape)
{
c_api.TFE_TapeSetRemove(tape);
}

public static bool IsDtypeTrainable(DataType dtype)
{
switch (dtype)


+ 8
- 10
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -715,17 +715,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
using (var status = new Status())
using var status = new Status();
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Mul", name, new IntPtr[]
{
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Mul", name, new IntPtr[]
{
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
}, 2, status);
status.Check(true);
return new EagerTensor(_result);
}
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
}, 2, status);
status.Check(true);
return new EagerTensor(_result);
}

var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });


Loading…
Cancel
Save