
Add PersistentTape unit test.

tags/v0.20
Oceania2018 · 5 years ago
commit 91fb30f3d9
5 changed files with 26 additions and 43 deletions
  1. +9 -3   src/TensorFlowNET.Core/APIs/tf.gradients.cs
  2. +0 -36  src/TensorFlowNET.Core/Eager/c_api.eager.cs
  3. +8 -0   src/TensorFlowNET.Core/Gradients/GradientTape.cs
  4. +6 -1   src/TensorFlowNET.Core/Operations/array_ops.cs
  5. +3 -3   test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs

+9 -3  src/TensorFlowNET.Core/APIs/tf.gradients.cs

@@ -1,5 +1,5 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,9 +20,15 @@ namespace Tensorflow
{
public partial class tensorflow
{
/// <summary>
/// Record operations for automatic differentiation.
/// </summary>
/// <param name="persistent"></param>
/// <param name="watch_accessed_variables"></param>
/// <returns></returns>
public GradientTape GradientTape(bool persistent = false,
    bool watch_accessed_variables = true)
    => new GradientTape(persistent: persistent,
        watch_accessed_variables: watch_accessed_variables);

public Tensor[] gradients(Tensor[] ys,
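
The factory above just forwards both flags to the GradientTape constructor. A minimal usage sketch, assuming the usual TF.NET entry point (`using static Tensorflow.Binding;`) and tf.constant; the values and variable names are illustrative and not taken from this commit:

using static Tensorflow.Binding;     // standard TF.NET import exposing `tf`

// Record y = x * x on a tape, then ask for dy/dx.
var x = tf.constant(3.0f);
var tape = tf.GradientTape();        // persistent: false, watch_accessed_variables: true
tape.watch(x);                       // explicitly track the tensor
var y = tf.multiply(x, x);
tape.Dispose();                      // stop recording, as the unit test in this commit does
var dy_dx = tape.gradient(y, x);     // expected value: 6.0f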


+0 -36  src/TensorFlowNET.Core/Eager/c_api.eager.cs

@@ -389,42 +389,6 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx);

/// <summary>
///
/// </summary>
/// <param name="ctx"></param>
/// <param name="device_name"></param>
/// <param name="op_name"></param>
/// <param name="name"></param>
/// <param name="args"></param>
/// <param name="input_size"></param>
/// <param name="set_op_attrs"></param>
/// <param name="status"></param>
/// <returns>EagerTensorHandle</returns>
[DllImport(TensorFlowLibName)]
public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx,
string device_name,
string op_name,
string name,
IntPtr[] inputs,
int input_size,
string attrs_string,
TFE_FastPathExecute_SetOpAttrs set_op_attrs,
IntPtr[] outputs,
int output_size);
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op);

[DllImport(TensorFlowLibName)]
public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx,
string device_name,
string op_name,
IntPtr[] inputs,
int input_size,
TFE_FastPathExecute_SetOpAttrs set_op_attrs,
IntPtr[] outputs,
int output_size);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);
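
Only the TFE_TapeSetNew binding survives this hunk; the fast-path declarations are removed. A hedged sketch of how a managed caller might use that remaining binding; the wrapper type and field are hypothetical, and only the TFE_TapeSetNew signature comes from the declaration shown here:

using System;
using Tensorflow;

// Hypothetical wrapper, for illustration only.
class TapeHandleSketch
{
    readonly IntPtr _tape;

    public TapeHandleSketch(bool persistent, bool watchAccessedVariables)
    {
        // Creates the native tape; a persistent tape can serve
        // more than one gradient computation.
        _tape = c_api.TFE_TapeSetNew(persistent, watchAccessedVariables);
    }
}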



+8 -0  src/TensorFlowNET.Core/Gradients/GradientTape.cs

@@ -142,6 +142,14 @@ namespace Tensorflow.Gradients
return results;
}

/// <summary>
/// Temporarily stops recording operations on this tape.
/// </summary>
public void stop_recording()
{
_pop_tape();
}

public void Dispose()
{
if (_recording)
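
stop_recording() only pops the tape, so operations executed afterwards are no longer tracked. A small usage sketch, assuming gradient() remains callable once recording has stopped (the unit test in this commit calls it after Dispose() in the same way); names and values are illustrative:

var x = tf.constant(2.0f);
var tape = tf.GradientTape();
tape.watch(x);
var y = tf.multiply(x, x);        // recorded while the tape is active
tape.stop_recording();            // pops the tape: later ops are not recorded
var check = tf.reduce_sum(y);     // bookkeeping that should not be differentiated
var dy_dx = tape.gradient(y, x);  // expected value: 4.0f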


+6 -1  src/TensorFlowNET.Core/Operations/array_ops.cs

@@ -328,7 +328,7 @@ namespace Tensorflow
{
dtype = dtype.as_base_dtype();
name = scope;
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
Tensor ones = null;
switch (dtype)
{
@@ -342,6 +342,11 @@ namespace Tensorflow
ones = constant(1);
break;
}

if (shape.ndim == 0)
return ones;
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
return fill(shape_tensor, ones, name: name);
});
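
Net effect of the two hunks above: the shape_tensor conversion moves below the rank-0 check, so a scalar shape returns the constant `ones` directly and the shape tensor is only built when fill() is actually reached. Paraphrasing the resulting tail of the method, with names copied from the hunks:

// `ones` holds the scalar 1 for the requested dtype after the switch above
if (shape.ndim == 0)
    return ones;                  // rank-0 shape: no fill() required

// build the shape tensor only when a filled tensor is actually produced
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
return fill(shape_tensor, ones, name: name);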



+3 -3  test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs

@@ -46,15 +46,15 @@ namespace TensorFlowNET.UnitTest.Gradient
tape.watch(x);
var y = tf.reduce_sum(x);
var z = tf.multiply(y, y);
tape.Dispose();

var dz_dx = tape.gradient(z, x);

var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));

var dz_dy = tape.gradient(z, y);

expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));
Assert.AreEqual((float)dz_dy, 8.0f);
}
}
}
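
The hunk above only shows the changed assertions; the tape construction sits outside it. Since gradient() is called twice after Dispose(), and the commit is titled "Add PersistentTape unit test.", the tape is presumably created with persistent: true. A hedged reconstruction of the full test body; the x initialization and the persistent flag are assumptions, not part of the hunk:

var x = tf.ones((2, 2));             // assumption: any all-ones 2x2 tensor fits the expected values
var tape = tf.GradientTape(persistent: true);
tape.watch(x);
var y = tf.reduce_sum(x);            // y = 4
var z = tf.multiply(y, y);           // z = y * y = 16
tape.Dispose();

var dz_dx = tape.gradient(z, x);     // dz/dx_i = 2 * y = 8 for every element
var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));

var dz_dy = tape.gradient(z, y);     // a second gradient() call needs a persistent tape
Assert.AreEqual((float)dz_dy, 8.0f);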
