|
|
@@ -5,19 +5,25 @@ using System.Text; |
|
|
|
using Tensorflow.Common.Types; |
|
|
|
using Tensorflow.Keras.Layers.Rnn; |
|
|
|
using Tensorflow.Common.Extensions; |
|
|
|
using System.Linq; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Utils |
|
|
|
{ |
|
|
|
internal static class RnnUtils |
|
|
|
{ |
|
|
|
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) |
|
|
|
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) |
|
|
|
{ |
|
|
|
Func<GeneralizedTensorShape, Tensor> create_zeros; |
|
|
|
create_zeros = (GeneralizedTensorShape unnested_state_size) => |
|
|
|
{ |
|
|
|
var flat_dims = unnested_state_size.ToSingleShape().dims; |
|
|
|
var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray(); |
|
|
|
return array_ops.zeros(new Shape(init_state_size), dtype: dtype); |
|
|
|
var init_state_size = new List<object> { batch_size_tensor}; |
|
|
|
foreach(var dim in flat_dims) |
|
|
|
{ |
|
|
|
init_state_size.add(dim); |
|
|
|
} |
|
|
|
var init_state_size_tensor = ops.convert_to_tensor(init_state_size.ToArray()); |
|
|
|
return array_ops.zeros(init_state_size_tensor); |
|
|
|
}; |
|
|
|
|
|
|
|
// TODO(Rinne): map structure with nested tensors. |
|
|
@@ -34,12 +40,13 @@ namespace Tensorflow.Keras.Utils |
|
|
|
|
|
|
|
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype) |
|
|
|
{ |
|
|
|
Tensor batch_size_tensor = tf.convert_to_tensor(batch_size); |
|
|
|
if (inputs != null) |
|
|
|
{ |
|
|
|
batch_size = inputs.shape[0]; |
|
|
|
batch_size_tensor = tf.shape(inputs)[0]; |
|
|
|
dtype = inputs.dtype; |
|
|
|
} |
|
|
|
return generate_zero_filled_state(batch_size, cell.StateSize, dtype); |
|
|
|
return generate_zero_filled_state(batch_size_tensor, cell.StateSize, dtype); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|