Browse Source

Fix _SliceGrad #800

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
a653914321
3 changed files with 8 additions and 8 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/math_ops.cs
  3. +6
    -6
      test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -258,7 +258,7 @@ namespace Tensorflow.Gradients
var input_rank = array_ops.rank(input_vec);
var slice_size = array_ops.shape(op.outputs[0]);

var shape = array_ops.stack(new Tensor[] { input_rank, new Tensor(1) });
var shape = array_ops.stack(new Tensor[] { input_rank, ops.convert_to_tensor(1) });
var before_pad = array_ops.reshape(begin_vec, shape);
var after_pad = array_ops.reshape(array_ops.shape(input_vec) - slice_size - begin_vec, shape);
var paddings = array_ops.concat(new Tensor[] { before_pad, after_pad }, 1);


+ 1
- 1
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -441,7 +441,7 @@ namespace Tensorflow
if (tf.Context.executing_eagerly())
{
var input_shape_val = input_shape.numpy();
foreach (var axes_val in axes.numpy().ToArray<int>())
foreach (var axes_val in axes.ToArray<int>())
input_shape_val[axes_val] = 1;
return tf.constant(input_shape_val);
}


+ 6
- 6
test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs View File

@@ -51,29 +51,29 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestMethod]
public void GradientSliceTest()
{
var X = tf.zeros(new TensorShape(10));
var X = tf.zeros(10);
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");
using var g = tf.GradientTape();
var pred = W * X + b;
var test = tf.slice(pred, new[] { 0 }, pred.shape);
var gradients = g.gradient(test, (W, b));
Assert.AreNotEqual(gradients.Item1, null);
Assert.AreNotEqual(gradients.Item2, null);
Assert.AreEqual((float)gradients.Item1, 0f);
Assert.AreEqual((float)gradients.Item2, 10f);
}

[TestMethod]
public void GradientConcatTest()
{
var X = tf.zeros(new TensorShape(10));
var X = tf.zeros(10);
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");
var test = tf.concat(new Tensor[] { W, b }, 0);
using var g = tf.GradientTape();
var pred = test[0] * X + test[1];
var gradients = g.gradient(pred, (W, b));
Assert.AreEqual((float)gradients.Item1, 0);
Assert.AreEqual((float)gradients.Item2, 10);
Assert.IsNull(gradients.Item1);
Assert.IsNull(gradients.Item2);
}
}
}

Loading…
Cancel
Save