diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 6b9f073c..55f13920 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -428,9 +428,9 @@ namespace Tensorflow.Operations return x; var x_rank = array_ops.rank(x); - var con1 = new object[] + var con1 = new Tensor[] { - new []{1, 0 }, + new Tensor(new int[]{0, 2}), math_ops.range(2, x_rank) }; var x_t = array_ops.transpose(x, array_ops.concat(con1, 0)); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 57af3b83..1b424006 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -166,6 +166,11 @@ namespace Tensorflow throw new ValueError("mask cannot be scalar."); var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 })); + if (leading_size.rank == 0) + { + leading_size = expand_dims(leading_size, 0); + } + var shape1 = concat(new[] { shape(tensor_tensor)[$":{axis}"], @@ -185,7 +190,7 @@ namespace Tensorflow private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) { - var indices = squeeze(where(mask), axis: new[] { 1 }); + var indices = squeeze(where_v2(mask), axis: new[] { 1 }); return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis)); } @@ -940,12 +945,12 @@ namespace Tensorflow /// public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") { - return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); + return gen_array_ops.concat_v2(values, axis, name: name); } - public static Tensor concat(object[] values, int axis, string name = "concat") + public static Tensor concat(Tensor[] values, Axis axis, string name = "concat") { - return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); + return gen_array_ops.concat_v2(values, axis, name: name); } /// diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index 00d7d316..394a591a 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -287,7 +287,7 @@ namespace Tensorflow new[] { math_ops.subtract(rank, 1) }, new[] { constant_op.constant(1) }); - var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0); + var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0); var output = array_ops.reshape(logits, ops); // Set output shape if known. diff --git a/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs index 90de7874..8093c1f2 100644 --- a/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs @@ -3,6 +3,7 @@ using Tensorflow.NumPy; using System; using System.Linq; using static Tensorflow.Binding; +using Tensorflow; namespace TensorFlowNET.UnitTest.Basics { @@ -60,14 +61,14 @@ namespace TensorFlowNET.UnitTest.Basics Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray())); } - [TestMethod, Ignore] + [TestMethod] public void boolean_mask() { + if (!tf.executing_eagerly()) + tf.enable_eager_execution(); var tensor = new[] { 0, 1, 2, 3 }; var mask = np.array(new[] { true, false, true, false }); var masked = tf.boolean_mask(tensor, mask); - var sess = tf.Session(); - var result = sess.run(masked); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray())); } }