Browse Source

Add PersistentTape unit test.

tags/v0.20
Oceania2018 5 years ago
parent
commit
91fb30f3d9
5 changed files with 26 additions and 43 deletions
  1. +9
    -3
      src/TensorFlowNET.Core/APIs/tf.gradients.cs
  2. +0
    -36
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  3. +8
    -0
      src/TensorFlowNET.Core/Gradients/GradientTape.cs
  4. +6
    -1
      src/TensorFlowNET.Core/Operations/array_ops.cs
  5. +3
    -3
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs

+ 9
- 3
src/TensorFlowNET.Core/APIs/tf.gradients.cs View File

@@ -1,5 +1,5 @@
/***************************************************************************** /*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved.


Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@@ -20,9 +20,15 @@ namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
{ {
/// <summary>
/// Record operations for automatic differentiation.
/// </summary>
/// <param name="persistent"></param>
/// <param name="watch_accessed_variables"></param>
/// <returns></returns>
public GradientTape GradientTape(bool persistent = false, public GradientTape GradientTape(bool persistent = false,
bool watch_accessed_variables = true)
=> new GradientTape(persistent: persistent,
bool watch_accessed_variables = true)
=> new GradientTape(persistent: persistent,
watch_accessed_variables: watch_accessed_variables); watch_accessed_variables: watch_accessed_variables);


public Tensor[] gradients(Tensor[] ys, public Tensor[] gradients(Tensor[] ys,


+ 0
- 36
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -389,42 +389,6 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx);


/// <summary>
///
/// </summary>
/// <param name="ctx"></param>
/// <param name="device_name"></param>
/// <param name="op_name"></param>
/// <param name="name"></param>
/// <param name="args"></param>
/// <param name="input_size"></param>
/// <param name="set_op_attrs"></param>
/// <param name="status"></param>
/// <returns>EagerTensorHandle</returns>
[DllImport(TensorFlowLibName)]
public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx,
string device_name,
string op_name,
string name,
IntPtr[] inputs,
int input_size,
string attrs_string,
TFE_FastPathExecute_SetOpAttrs set_op_attrs,
IntPtr[] outputs,
int output_size);
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op);

[DllImport(TensorFlowLibName)]
public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx,
string device_name,
string op_name,
IntPtr[] inputs,
int input_size,
TFE_FastPathExecute_SetOpAttrs set_op_attrs,
IntPtr[] outputs,
int output_size);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);




+ 8
- 0
src/TensorFlowNET.Core/Gradients/GradientTape.cs View File

@@ -142,6 +142,14 @@ namespace Tensorflow.Gradients
return results; return results;
} }


/// <summary>
/// Temporarily stops recording operations on this tape.
/// </summary>
public void stop_recording()
{
_pop_tape();
}

public void Dispose() public void Dispose()
{ {
if (_recording) if (_recording)


+ 6
- 1
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -328,7 +328,7 @@ namespace Tensorflow
{ {
dtype = dtype.as_base_dtype(); dtype = dtype.as_base_dtype();
name = scope; name = scope;
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
Tensor ones = null; Tensor ones = null;
switch (dtype) switch (dtype)
{ {
@@ -342,6 +342,11 @@ namespace Tensorflow
ones = constant(1); ones = constant(1);
break; break;
} }

if (shape.ndim == 0)
return ones;
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
return fill(shape_tensor, ones, name: name); return fill(shape_tensor, ones, name: name);
}); });




+ 3
- 3
test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs View File

@@ -46,15 +46,15 @@ namespace TensorFlowNET.UnitTest.Gradient
tape.watch(x); tape.watch(x);
var y = tf.reduce_sum(x); var y = tf.reduce_sum(x);
var z = tf.multiply(y, y); var z = tf.multiply(y, y);
tape.Dispose();

var dz_dx = tape.gradient(z, x); var dz_dx = tape.gradient(z, x);


var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));


var dz_dy = tape.gradient(z, y); var dz_dy = tape.gradient(z, y);

expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));
Assert.AreEqual((float)dz_dy, 8.0f);
} }
} }
} }

Loading…
Cancel
Save