diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
index ef13b3f4..bb648bb9 100644
--- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
@@ -1,5 +1,5 @@
/*****************************************************************************
- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+ Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,9 +20,15 @@ namespace Tensorflow
{
public partial class tensorflow
{
+ ///
+ /// Record operations for automatic differentiation.
+ ///
+ ///
+ ///
+ ///
public GradientTape GradientTape(bool persistent = false,
- bool watch_accessed_variables = true)
- => new GradientTape(persistent: persistent,
+ bool watch_accessed_variables = true)
+ => new GradientTape(persistent: persistent,
watch_accessed_variables: watch_accessed_variables);
public Tensor[] gradients(Tensor[] ys,
diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index af63dd67..1b2fb6b8 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -389,42 +389,6 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx);
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- /// EagerTensorHandle
- [DllImport(TensorFlowLibName)]
- public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx,
- string device_name,
- string op_name,
- string name,
- IntPtr[] inputs,
- int input_size,
- string attrs_string,
- TFE_FastPathExecute_SetOpAttrs set_op_attrs,
- IntPtr[] outputs,
- int output_size);
- [UnmanagedFunctionPointer(CallingConvention.StdCall)]
- public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op);
-
- [DllImport(TensorFlowLibName)]
- public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx,
- string device_name,
- string op_name,
- IntPtr[] inputs,
- int input_size,
- TFE_FastPathExecute_SetOpAttrs set_op_attrs,
- IntPtr[] outputs,
- int output_size);
-
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);
diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs
index 32a5d415..69bf264f 100644
--- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs
+++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs
@@ -142,6 +142,14 @@ namespace Tensorflow.Gradients
return results;
}
+ ///
+ /// Temporarily stops recording operations on this tape.
+ ///
+ public void stop_recording()
+ {
+ _pop_tape();
+ }
+
public void Dispose()
{
if (_recording)
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index dab9d3ec..4cb55119 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -328,7 +328,7 @@ namespace Tensorflow
{
dtype = dtype.as_base_dtype();
name = scope;
- var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
+
Tensor ones = null;
switch (dtype)
{
@@ -342,6 +342,11 @@ namespace Tensorflow
ones = constant(1);
break;
}
+
+ if (shape.ndim == 0)
+ return ones;
+
+ var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
return fill(shape_tensor, ones, name: name);
});
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs
index c2d61e1f..7c24ee26 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs
+++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs
@@ -46,15 +46,15 @@ namespace TensorFlowNET.UnitTest.Gradient
tape.watch(x);
var y = tf.reduce_sum(x);
var z = tf.multiply(y, y);
+ tape.Dispose();
+
var dz_dx = tape.gradient(z, x);
var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected));
var dz_dy = tape.gradient(z, y);
-
- expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
- Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected));
+ Assert.AreEqual((float)dz_dy, 8.0f);
}
}
}