You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

RnnUtils.cs 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Text;
  5. using Tensorflow.Common.Types;
  6. using Tensorflow.Keras.Layers.Rnn;
  7. using Tensorflow.Common.Extensions;
  8. namespace Tensorflow.Keras.Utils
  9. {
  10. internal static class RnnUtils
  11. {
  12. internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
  13. {
  14. Func<GeneralizedTensorShape, Tensor> create_zeros;
  15. create_zeros = (GeneralizedTensorShape unnested_state_size) =>
  16. {
  17. var flat_dims = unnested_state_size.ToSingleShape().dims;
  18. var init_state_size = new Tensor[] { batch_size_tensor }.
  19. Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray();
  20. return array_ops.zeros(init_state_size, dtype: dtype);
  21. };
  22. // TODO(Rinne): map structure with nested tensors.
  23. if(state_size.TotalNestedCount > 1)
  24. {
  25. return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray());
  26. }
  27. else
  28. {
  29. return create_zeros(state_size);
  30. }
  31. }
  32. internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype)
  33. {
  34. if (inputs is not null)
  35. {
  36. batch_size = array_ops.shape(inputs)[0];
  37. dtype = inputs.dtype;
  38. }
  39. return generate_zero_filled_state(batch_size, cell.StateSize, dtype);
  40. }
  41. /// <summary>
  42. /// Standardizes `__call__` to a single list of tensor inputs.
  43. ///
  44. /// When running a model loaded from a file, the input tensors
  45. /// `initial_state` and `constants` can be passed to `RNN.__call__()` as part
  46. /// of `inputs` instead of by the dedicated keyword arguments.This method
  47. /// makes sure the arguments are separated and that `initial_state` and
  48. /// `constants` are lists of tensors(or None).
  49. /// </summary>
  50. /// <param name="inputs">Tensor or list/tuple of tensors. which may include constants
  51. /// and initial states.In that case `num_constant` must be specified.</param>
  52. /// <param name="initial_state">Tensor or list of tensors or None, initial states.</param>
  53. /// <param name="constants">Tensor or list of tensors or None, constant tensors.</param>
  54. /// <param name="num_constants">Expected number of constants (if constants are passed as
  55. /// part of the `inputs` list.</param>
  56. /// <returns></returns>
  57. internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants)
  58. {
  59. if(inputs.Length > 1)
  60. {
  61. // There are several situations here:
  62. // In the graph mode, __call__ will be only called once. The initial_state
  63. // and constants could be in inputs (from file loading).
  64. // In the eager mode, __call__ will be called twice, once during
  65. // rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be
  66. // model.fit/train_on_batch/predict with real np data. In the second case,
  67. // the inputs will contain initial_state and constants as eager tensor.
  68. //
  69. // For either case, the real input is the first item in the list, which
  70. // could be a nested structure itself. Then followed by initial_states, which
  71. // could be a list of items, or list of list if the initial_state is complex
  72. // structure, and finally followed by constants which is a flat list.
  73. Debug.Assert(initial_state is null && constants is null);
  74. if(num_constants > 0)
  75. {
  76. constants = inputs.TakeLast(num_constants).ToArray().ToTensors();
  77. inputs = inputs.SkipLast(num_constants).ToArray().ToTensors();
  78. }
  79. if(inputs.Length > 1)
  80. {
  81. initial_state = inputs.Skip(1).ToArray().ToTensors();
  82. inputs = inputs.Take(1).ToArray().ToTensors();
  83. }
  84. }
  85. return (inputs, initial_state, constants);
  86. }
  87. /// <summary>
  88. /// Check whether the state_size contains multiple states.
  89. /// </summary>
  90. /// <param name="state_size"></param>
  91. /// <returns></returns>
  92. public static bool is_multiple_state(GeneralizedTensorShape state_size)
  93. {
  94. return state_size.TotalNestedCount > 1;
  95. }
  96. }
  97. }