From 91fb30f3d980c2f7073f092f7a59b157cce4b83d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 28 Jun 2020 20:16:48 -0500 Subject: [PATCH] Add PersistentTape unit test. --- src/TensorFlowNET.Core/APIs/tf.gradients.cs | 12 +++++-- src/TensorFlowNET.Core/Eager/c_api.eager.cs | 36 ------------------- .../Gradients/GradientTape.cs | 8 +++++ .../Operations/array_ops.cs | 7 +++- .../NativeAPI/Eager/GradientEagerTest.cs | 6 ++-- 5 files changed, 26 insertions(+), 43 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs index ef13b3f4..bb648bb9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs +++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs @@ -1,5 +1,5 @@ /***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,9 +20,15 @@ namespace Tensorflow { public partial class tensorflow { + /// + /// Record operations for automatic differentiation. + /// + /// + /// + /// public GradientTape GradientTape(bool persistent = false, - bool watch_accessed_variables = true) - => new GradientTape(persistent: persistent, + bool watch_accessed_variables = true) + => new GradientTape(persistent: persistent, watch_accessed_variables: watch_accessed_variables); public Tensor[] gradients(Tensor[] ys, diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index af63dd67..1b2fb6b8 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -389,42 +389,6 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// EagerTensorHandle - [DllImport(TensorFlowLibName)] - public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx, - string device_name, - string op_name, - string name, - IntPtr[] inputs, - int input_size, - string attrs_string, - TFE_FastPathExecute_SetOpAttrs set_op_attrs, - IntPtr[] outputs, - int output_size); - [UnmanagedFunctionPointer(CallingConvention.StdCall)] - public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); - - [DllImport(TensorFlowLibName)] - public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx, - string device_name, - string op_name, - IntPtr[] inputs, - int input_size, - TFE_FastPathExecute_SetOpAttrs set_op_attrs, - IntPtr[] outputs, - int output_size); - [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs index 32a5d415..69bf264f 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -142,6 +142,14 @@ namespace Tensorflow.Gradients return results; } + /// + /// Temporarily stops recording operations on this tape. + /// + public void stop_recording() + { + _pop_tape(); + } + public void Dispose() { if (_recording) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index dab9d3ec..4cb55119 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -328,7 +328,7 @@ namespace Tensorflow { dtype = dtype.as_base_dtype(); name = scope; - var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); + Tensor ones = null; switch (dtype) { @@ -342,6 +342,11 @@ namespace Tensorflow ones = constant(1); break; } + + if (shape.ndim == 0) + return ones; + + var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); return fill(shape_tensor, ones, name: name); }); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs index c2d61e1f..7c24ee26 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs @@ -46,15 +46,15 @@ namespace TensorFlowNET.UnitTest.Gradient tape.watch(x); var y = tf.reduce_sum(x); var z = tf.multiply(y, y); + tape.Dispose(); + var dz_dx = tape.gradient(z, x); var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected)); var dz_dy = tape.gradient(z, y); - - expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; - Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected)); + Assert.AreEqual((float)dz_dy, 8.0f); } } }