|
|
@@ -58,10 +58,7 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
|
|
|
|
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) |
|
|
|
{ |
|
|
|
Console.WriteLine($"shape of input: {inputs.shape}"); |
|
|
|
Tensor states = initial_state[0]; |
|
|
|
Console.WriteLine($"shape of initial_state: {states.shape}"); |
|
|
|
|
|
|
|
var prev_output = nest.is_nested(states) ? states[0] : states; |
|
|
|
var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value); |
|
|
|
var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); |
|
|
@@ -72,11 +69,12 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
{ |
|
|
|
if (ranks > 2) |
|
|
|
{ |
|
|
|
h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); |
|
|
|
// 因为multiply函数会自动添加第一个维度,所以加上下标0 |
|
|
|
h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor()); |
|
|
|
h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor()); |
|
|
|
} |
|
|
|
} |
|
|
|
else |
|
|
@@ -98,22 +96,18 @@ namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
|
|
|
|
if (rec_dp_mask != null) |
|
|
|
{ |
|
|
|
prev_output = tf.multiply(prev_output, rec_dp_mask); |
|
|
|
prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0]; |
|
|
|
} |
|
|
|
|
|
|
|
ranks = prev_output.rank; |
|
|
|
Console.WriteLine($"shape of h: {h.shape}"); |
|
|
|
|
|
|
|
Tensor output; |
|
|
|
if (ranks > 2) |
|
|
|
{ |
|
|
|
var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); |
|
|
|
output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0]; |
|
|
|
output = h + tf.linalg.tensordot(prev_output[0], recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0]; |
|
|
|
|
|
|
|
output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor()); |
|
|
|
} |
|
|
|
Console.WriteLine($"shape of output: {output.shape}"); |
|
|
|
|
|
|
|