Browse Source

Merge pull request #1208 from Wanglongzhi2001/fix_concat_v2_bug

fix: fix the bug caused by concat_v2
tags/v0.150.0-BERT-Model
Haiping GitHub 1 year ago
parent
commit
6aa2ab2afe
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 6 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/nn_ops.cs

+ 2
- 2
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -428,9 +428,9 @@ namespace Tensorflow.Operations
return x; return x;


var x_rank = array_ops.rank(x); var x_rank = array_ops.rank(x);
var con1 = new Tensor[]
var con1 = new object[]
{ {
new Tensor(new int[]{0, 2}),
new []{1, 0 },
math_ops.range(2, x_rank) math_ops.range(2, x_rank)
}; };
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0)); var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));


+ 3
- 3
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -945,12 +945,12 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
{ {
return gen_array_ops.concat_v2(values, axis, name: name);
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
} }


public static Tensor concat(Tensor[] values, Axis axis, string name = "concat")
public static Tensor concat(object[] values, int axis, string name = "concat")
{ {
return gen_array_ops.concat_v2(values, axis, name: name);
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
} }


/// <summary> /// <summary>


+ 1
- 1
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -287,7 +287,7 @@ namespace Tensorflow
new[] { math_ops.subtract(rank, 1) }, new[] { math_ops.subtract(rank, 1) },
new[] { constant_op.constant(1) }); new[] { constant_op.constant(1) });


var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0);
var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
var output = array_ops.reshape(logits, ops); var output = array_ops.reshape(logits, ops);


// Set output shape if known. // Set output shape if known.


Loading…
Cancel
Save