
_AddNGrad

Oceania2018 committed 5 years ago · commit 43ab3cad88 · tags/v0.20
11 changed files with 128 additions and 70 deletions
 1. src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs  (+18 -10)
 2. src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs  (+1 -4)
 3. src/TensorFlowNET.Core/Gradients/GradientTape.cs  (+2 -2)
 4. src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs  (+2 -0)
 5. src/TensorFlowNET.Core/Gradients/math_grad.cs  (+16 -0)
 6. src/TensorFlowNET.Core/Operations/gen_array_ops.cs  (+20 -0)
 7. src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs  (+3 -0)
 8. src/TensorFlowNET.Core/Util/UnorderedMap.cs  (+2 -2)
 9. src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs  (+16 -51)
10. src/TensorFlowNET.Core/ops.cs  (+1 -1)
11. test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs  (+47 -0)

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs  (+18 -10)

@@ -64,10 +64,8 @@ namespace Tensorflow.Eager
                 }
             }
 
-            var flattened_inputs = args.Take(op_def.InputArg.Count)
-                .Select(x => x as Tensor)
-                .ToArray();
-            var flattened_attrs = args.Skip(op_def.InputArg.Count).ToArray();
+            var flattened_attrs = new List<object>(op_def.InputArg.Count);
+            var flattened_inputs = new List<Tensor>(op_def.InputArg.Count);
 
             c_api.TFE_OpSetDevice(op, device_name, status.Handle);
             status.Check(true);
@@ -80,31 +78,36 @@ namespace Tensorflow.Eager
                 {
                     int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length;
                     c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
+                    if (op_exec_info.run_callbacks)
+                    {
+                        flattened_attrs.Add(input_arg.NumberAttr);
+                        flattened_attrs.Add(len);
+                    }
                     attr_list_sizes[input_arg.NumberAttr] = len;
 
                     if (len > 0)
                     {
                         var fast_input_array = (object[])args[i];
                         // First item adds the type attr.
-                        if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status))
+                        if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status))
                             return null;
 
                         for (var j = 1; j < len; j++)
                         {
                             // Since the list is homogeneous, we don't need to re-add the attr.
-                            if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status))
+                            if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status))
                                 return null;
                         }
                     }
                 }
                 else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
                 {
                     throw new NotImplementedException("");
                 }
                 else
                 {
                     // The item is a single item.
-                    AddInputToOp(args[i], true, input_arg, op, status);
+                    AddInputToOp(args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status);
                 }
             }
 
@@ -133,7 +136,7 @@ namespace Tensorflow.Eager
             if (!RunCallbacks(
                 op_exec_info,
                 kFastPathExecuteInputStartIndex + op_def.InputArg.Count(),
-                flattened_inputs, flattened_attrs, flat_result))
+                flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result))
             {
                 return null;
             }
@@ -187,6 +190,8 @@ namespace Tensorflow.Eager
         bool AddInputToOp(object inputs,
             bool add_type_attr,
             ArgDef input_arg,
+            List<object> flattened_attrs,
+            List<Tensor> flattened_inputs,
             IntPtr op,
             Status status)
         {
@@ -197,6 +202,7 @@ namespace Tensorflow.Eager
             {
                 case EagerTensor input:
                     input_handle = input.EagerTensorHandle;
+                    flattened_inputs.Add(input);
                     break;
                 case EagerTensor[] input_list:
                     input_handle = input_list[0].EagerTensorHandle;
@@ -211,6 +217,8 @@ namespace Tensorflow.Eager
             {
                 var dtype = c_api.TFE_TensorHandleDataType(input_handle);
                 c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype);
+                flattened_attrs.Add(input_arg.TypeAttr);
+                flattened_attrs.Add(dtype);
             }
 
             c_api.TFE_OpAddInput(op, input_handle, status.Handle);
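
The fast path used to slice args into inputs and attrs up front; now both lists are accumulated inside AddInputToOp, so RunCallbacks (the hook gradient tapes use to record ops) receives exactly the tensors and inferred dtype attrs that were actually set on the eager op, in order. A minimal self-contained sketch of that accumulation pattern; TensorStub is an illustrative stand-in for TF.NET's Tensor, not a real type:

    using System.Collections.Generic;

    // Illustrative stand-in for TF.NET's Tensor.
    class TensorStub { public string Dtype = "DT_FLOAT"; }

    class OpBuilderSketch
    {
        public readonly List<TensorStub> FlattenedInputs = new List<TensorStub>();
        public readonly List<object> FlattenedAttrs = new List<object>();

        // Mirrors AddInputToOp: every input records itself; the first
        // input of a typed argument also records its dtype attr.
        public void AddInput(TensorStub t, bool addTypeAttr, string typeAttrName)
        {
            FlattenedInputs.Add(t);
            if (addTypeAttr)
            {
                FlattenedAttrs.Add(typeAttrName);
                FlattenedAttrs.Add(t.Dtype);
            }
        }
    }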


src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs  (+1 -4)

@@ -34,7 +34,7 @@ namespace Tensorflow.Eager
 
         public EagerTensor Resolve()
         {
-            _id = get_uid();
+            _id = ops.uid();
 
             if (_handle == IntPtr.Zero)
                 _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status.Handle);
@@ -55,8 +55,5 @@ namespace Tensorflow.Eager
             //print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
             c_api.TFE_DeleteTensorHandle(EagerTensorHandle);
         }
-
-        static long _uid = 0;
-        long get_uid() => _uid++;
     }
 }
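
With the per-class counter gone, eager tensor ids are drawn from the shared ops.uid() sequence (see the ops.cs change below), so every uid consumer in the process gets ids from a single monotonically increasing source rather than two independent counters.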

src/TensorFlowNET.Core/Gradients/GradientTape.cs  (+2 -2)

@@ -24,8 +24,8 @@ namespace Tensorflow.Gradients
     /// </summary>
     public class GradientTape : IDisposable
     {
-        static bool _recording;
-        public static bool Recording => _recording;
+        bool _recording;
+        public bool Recording => _recording;
         bool _persistent;
         bool _watch_accessed_variables;
        ResourceVariable[] _watched_variables;
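
_recording was previously static, i.e. shared by every tape; making it instance state is what allows two tapes to be alive at once, which the new HighGradient test at the bottom of this commit exercises. Usage sketch of nested tapes for a second derivative, mirroring that test (where disposing a tape stops its recording before gradient() is called):

    var x = tf.Variable(1.0f);
    using var tape1 = tf.GradientTape();
    using var tape2 = tf.GradientTape();
    var y = x * x * x;                        // recorded by both tapes
    tape2.Dispose();                          // stop the inner tape
    var dy_dx = tape2.gradient(y, x);         // 3 * x^2 = 3 at x = 1
    tape1.Dispose();                          // stop the outer tape
    var d2y_dx2 = tape1.gradient(dy_dx, x);   // 6 * x = 6 at x = 1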


src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs  (+2 -0)

@@ -13,12 +13,14 @@ namespace Tensorflow.Gradients
                 "FusedBatchNormGradV3" => new[] { 5 },
                 "FusedBatchNormV2" => new[] { 2 },
                 "FusedBatchNormV3" => new[] { 2 },
+                "ReadVariableOp" => new int[0],
                 _ => null
             };
 
         public static int[] OpGradientUnusedOutputIndices(string op_name)
             => op_name switch
             {
+                "ReadVariableOp" => new int[0],
                 "SoftmaxCrossEntropyWithLogits" => new[] { 0 },
                 "TensorArrayConcat" => new[] { 0 },
                 "TensorArrayConcatV2" => new[] { 0 },


src/TensorFlowNET.Core/Gradients/math_grad.cs  (+16 -0)

@@ -64,6 +64,22 @@ namespace Tensorflow.Gradients
             return new Tensor[] { r1, r2 };
         }
 
+        /// <summary>
+        /// Copies the gradient to all inputs.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="grads"></param>
+        /// <returns></returns>
+        [RegisterGradient("AddN")]
+        public static Tensor[] _AddNGrad(Operation op, Tensor[] grads)
+        {
+            var grad = grads[0];
+
+            return Enumerable.Range(0, len(op.inputs))
+                .Select(x => grad)
+                .ToArray();
+        }
+
         [RegisterGradient("Cumsum")]
         public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads)
         {
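
The new _AddNGrad is the gradient this commit is named for. AddN(x1, ..., xn) is an elementwise sum, so dy/dxi = 1 for every input and the upstream gradient fans out unchanged to all n inputs. The same fan-out logic in a standalone form, using plain generics instead of TF.NET types:

    using System.Linq;

    static class AddNGradSketch
    {
        // y = x1 + ... + xn  =>  dy/dxi = 1, so each of the n inputs
        // receives the upstream gradient as-is.
        public static T[] FanOut<T>(T upstreamGrad, int numInputs)
            => Enumerable.Repeat(upstreamGrad, numInputs).ToArray();
    }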


src/TensorFlowNET.Core/Operations/gen_array_ops.cs  (+20 -0)

@@ -124,6 +124,16 @@ namespace Tensorflow
         /// </remarks>
         public static Tensor diag(Tensor diagonal, string name = null)
         {
+            if (tf.context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
+                    "Diag", name,
+                    null,
+                    diagonal);
+
+                return results[0];
+            }
+
             var op = tf._op_def_lib._apply_op_helper("Diag", name: name, args: new { diagonal });
 
             return op.output;
@@ -131,6 +141,16 @@
 
         public static Tensor expand_dims(Tensor input, int axis, string name = null)
         {
+            if (tf.context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
+                    "ExpandDims", name,
+                    null,
+                    input, tf.convert_to_tensor(axis));
+
+                return results[0];
+            }
+
             var _op = tf._op_def_lib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis });
 
             return _op.outputs[0];
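
Both methods follow the eager fast-path template used throughout these generated ops: if eager execution is active, dispatch through TFE_FastPathExecute and return the first result; otherwise fall back to the graph-mode op builder. The template, with "SomeOp" as a placeholder op name:

    public static Tensor some_op(Tensor input, string name = null)
    {
        if (tf.context.executing_eagerly())
        {
            // Eager mode: run the kernel immediately via the fast path.
            var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
                "SomeOp", name,
                null,
                input);
            return results[0];
        }

        // Graph mode: record an op in the current graph instead.
        var _op = tf._op_def_lib._apply_op_helper("SomeOp", name: name, args: new { input });
        return _op.output;
    }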


src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs  (+3 -0)

@@ -39,5 +39,8 @@ namespace Tensorflow.Util
         }
 
         public override bool IsInvalid => handle == IntPtr.Zero;
+
+        public override string ToString()
+            => $"0x{handle.ToString("x16")}";
     }
 }
}

src/TensorFlowNET.Core/Util/UnorderedMap.cs  (+2 -2)

@@ -28,10 +28,10 @@ namespace Tensorflow.Util
         }
 
         public void push_back(Tk key, Tv value)
-            => Add(key, value);
+            => this[key] = value;
 
         public void emplace(Tk key, Tv value)
-            => Add(key, value);
+            => this[key] = value;
 
         public bool find(Tk key)
             => ContainsKey(key);


src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs  (+16 -51)

@@ -22,56 +22,21 @@ namespace Tensorflow
 {
     public partial class ResourceVariable
     {
-        public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y);
-        public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y);
-        public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y);
-        public static Tensor operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y);
-        public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y);
-        public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y);
-        public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y);
-        public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y);
-        public static Tensor operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y);
-
-        public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y);
-        public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y);
-
-        public static Tensor operator <(ResourceVariable x, Tensor y) => op_helper("less", x, y);
-
-        public static Tensor operator >(ResourceVariable x, Tensor y) => op_helper("greater", x, y);
-
-        private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y)
-            => tf_with(ops.name_scope(null, default_name, new { x, y }), scope =>
-            {
-                string name = scope;
-                var xVal = x.value();
-                var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y");
-                Tensor result = null;
-                switch (default_name)
-                {
-                    case "add":
-                        result = x.dtype == TF_DataType.TF_STRING ?
-                            gen_math_ops.add(xVal, yTensor, name) :
-                            gen_math_ops.add_v2(xVal, yTensor, name);
-                        break;
-                    case "sub":
-                        result = gen_math_ops.sub(xVal, yTensor, name);
-                        break;
-                    case "mul":
-                        result = gen_math_ops.mul(xVal, yTensor, name: name);
-                        break;
-                    case "less":
-                        result = gen_math_ops.less(xVal, yTensor, name);
-                        break;
-                    case "greater":
-                        result = gen_math_ops.greater(xVal, yTensor, name);
-                        break;
-                    default:
-                        throw new NotImplementedException("");
-                }
-
-                // x.assign(result);
-                // result.ResourceVar = x;
-                return result;
-            });
+        public static Tensor operator +(ResourceVariable x, int y) => x.value() + y;
+        public static Tensor operator +(ResourceVariable x, float y) => x.value() + y;
+        public static Tensor operator +(ResourceVariable x, double y) => x.value() + y;
+        public static Tensor operator +(ResourceVariable x, ResourceVariable y) => x.value() + y.value();
+        public static Tensor operator -(ResourceVariable x, int y) => x.value() - y;
+        public static Tensor operator -(ResourceVariable x, float y) => x.value() - y;
+        public static Tensor operator -(ResourceVariable x, double y) => x.value() - y;
+        public static Tensor operator -(ResourceVariable x, Tensor y) => x.value() - y;
+        public static Tensor operator -(ResourceVariable x, ResourceVariable y) => x.value() - y.value();
+
+        public static Tensor operator *(ResourceVariable x, ResourceVariable y) => x.value() * y.value();
+        public static Tensor operator *(ResourceVariable x, NDArray y) => x.value() * y;
+
+        public static Tensor operator <(ResourceVariable x, Tensor y) => x.value() < y;
+
+        public static Tensor operator >(ResourceVariable x, Tensor y) => x.value() > y;
     }
 }
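
The operators now delegate to Tensor's own operator overloads via x.value(), deleting the string-dispatched op_helper. One consequence worth noting: special cases the helper handled, such as choosing gen_math_ops.add over add_v2 for string dtypes, are now whatever Tensor's operators do. Call sites are unchanged:

    var v = tf.Variable(3.0f);
    Tensor sum = v + 2.0f;   // now exactly v.value() + 2.0f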

src/TensorFlowNET.Core/ops.cs  (+1 -1)

@@ -277,7 +277,7 @@ namespace Tensorflow
             return ops.control_dependencies(null);
         }
 
-        private static int uid_number = 0;
+        private static int uid_number = -1;
 
         /// <summary>
         /// A unique (within this program execution) integer.
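
Starting uid_number at -1 suggests uid() increments before returning, so the first id handed out is 0. A minimal sketch of that scheme, assuming an Interlocked-based implementation (the method body itself is not shown in this diff):

    using System.Threading;

    static class UidSketch
    {
        static int uid_number = -1;

        // Atomic pre-increment: first call yields 0, then 1, 2, ...
        public static int uid() => Interlocked.Increment(ref uid_number);
    }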


test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs  (+47 -0)

@@ -24,6 +24,25 @@ namespace TensorFlowNET.UnitTest.Gradient
             Assert.AreEqual((float)grad, 3.0f);
         }
 
+        /// <summary>
+        /// Calculate the gradient of w * w * w
+        /// (higher-order gradient).
+        /// </summary>
+        [TestMethod]
+        public void HighGradient()
+        {
+            var x = tf.Variable(1.0f);
+            using var tape1 = tf.GradientTape();
+            using var tape2 = tf.GradientTape();
+            var y = x * x * x;
+            tape2.Dispose();
+            var dy_dx = tape2.gradient(y, x);
+            Assert.AreEqual((float)dy_dx, 3.0f);
+            tape1.Dispose();
+            var d2y_d2x = tape1.gradient(dy_dx, x);
+            Assert.AreEqual((float)d2y_d2x, 6.0f);
+        }
+
         [TestMethod]
         public void ConstantMultiply()
         {
@@ -56,5 +75,33 @@ namespace TensorFlowNET.UnitTest.Gradient
             var dz_dy = tape.gradient(z, y);
             Assert.AreEqual((float)dz_dy, 8.0f);
         }
+
+        [TestMethod]
+        public void ConditionalMultiply()
+        {
+            Func<Tensor, int, Tensor> func = (x, y) =>
+            {
+                Tensor output = tf.constant(1.0f);
+                foreach (var i in range(y))
+                {
+                    if (i > 1)
+                        output = tf.multiply(output, x);
+                }
+                return output;
+            };
+
+            Func<Tensor, int, Tensor> grad = (x, y) =>
+            {
+                using var tape = tf.GradientTape();
+                tape.watch(x);
+                var g = tape.gradient(func(x, y), x);
+                return g;
+            };
+
+            var x = tf.constant(2.0f);
+            var result = grad(x, 4);
+            Assert.AreEqual((float)result, 4.0f);
+        }
     }
 }
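
The arithmetic behind the new assertions: in HighGradient, y = x^3 gives dy/dx = 3x^2 = 3 and d2y/dx2 = 6x = 6 at x = 1. In ConditionalMultiply with y = 4, the loop multiplies only at i = 2 and i = 3, so func computes x^2, and d(x^2)/dx = 2x = 4 at x = 2, matching each assertion.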
