@@ -1,5 +1,5 @@ | |||||
/***************************************************************************** | /***************************************************************************** | ||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
you may not use this file except in compliance with the License. | you may not use this file except in compliance with the License. | ||||
@@ -20,9 +20,15 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
{ | { | ||||
/// <summary> | |||||
/// Record operations for automatic differentiation. | |||||
/// </summary> | |||||
/// <param name="persistent"></param> | |||||
/// <param name="watch_accessed_variables"></param> | |||||
/// <returns></returns> | |||||
public GradientTape GradientTape(bool persistent = false, | public GradientTape GradientTape(bool persistent = false, | ||||
bool watch_accessed_variables = true) | |||||
=> new GradientTape(persistent: persistent, | |||||
bool watch_accessed_variables = true) | |||||
=> new GradientTape(persistent: persistent, | |||||
watch_accessed_variables: watch_accessed_variables); | watch_accessed_variables: watch_accessed_variables); | ||||
public Tensor[] gradients(Tensor[] ys, | public Tensor[] gradients(Tensor[] ys, | ||||
@@ -389,42 +389,6 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | ||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="ctx"></param> | |||||
/// <param name="device_name"></param> | |||||
/// <param name="op_name"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="args"></param> | |||||
/// <param name="input_size"></param> | |||||
/// <param name="set_op_attrs"></param> | |||||
/// <param name="status"></param> | |||||
/// <returns>EagerTensorHandle</returns> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx, | |||||
string device_name, | |||||
string op_name, | |||||
string name, | |||||
IntPtr[] inputs, | |||||
int input_size, | |||||
string attrs_string, | |||||
TFE_FastPathExecute_SetOpAttrs set_op_attrs, | |||||
IntPtr[] outputs, | |||||
int output_size); | |||||
[UnmanagedFunctionPointer(CallingConvention.StdCall)] | |||||
public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx, | |||||
string device_name, | |||||
string op_name, | |||||
IntPtr[] inputs, | |||||
int input_size, | |||||
TFE_FastPathExecute_SetOpAttrs set_op_attrs, | |||||
IntPtr[] outputs, | |||||
int output_size); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | ||||
@@ -142,6 +142,14 @@ namespace Tensorflow.Gradients | |||||
return results; | return results; | ||||
} | } | ||||
/// <summary> | |||||
/// Temporarily stops recording operations on this tape. | |||||
/// </summary> | |||||
public void stop_recording() | |||||
{ | |||||
_pop_tape(); | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
if (_recording) | if (_recording) | ||||
@@ -328,7 +328,7 @@ namespace Tensorflow | |||||
{ | { | ||||
dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
name = scope; | name = scope; | ||||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
Tensor ones = null; | Tensor ones = null; | ||||
switch (dtype) | switch (dtype) | ||||
{ | { | ||||
@@ -342,6 +342,11 @@ namespace Tensorflow | |||||
ones = constant(1); | ones = constant(1); | ||||
break; | break; | ||||
} | } | ||||
if (shape.ndim == 0) | |||||
return ones; | |||||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
return fill(shape_tensor, ones, name: name); | return fill(shape_tensor, ones, name: name); | ||||
}); | }); | ||||
@@ -46,15 +46,15 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
tape.watch(x); | tape.watch(x); | ||||
var y = tf.reduce_sum(x); | var y = tf.reduce_sum(x); | ||||
var z = tf.multiply(y, y); | var z = tf.multiply(y, y); | ||||
tape.Dispose(); | |||||
var dz_dx = tape.gradient(z, x); | var dz_dx = tape.gradient(z, x); | ||||
var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | ||||
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | ||||
var dz_dy = tape.gradient(z, y); | var dz_dy = tape.gradient(z, y); | ||||
expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | |||||
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||||
Assert.AreEqual((float)dz_dy, 8.0f); | |||||
} | } | ||||
} | } | ||||
} | } |