You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

RNN.cs 21 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. using OneOf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Reflection;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine;
  7. using Tensorflow.Keras.Saving;
  8. using Tensorflow.Util;
  9. using Tensorflow.Common.Extensions;
  10. using System.Linq.Expressions;
  11. using Tensorflow.Keras.Utils;
  12. using Tensorflow.Common.Types;
  13. using System.Runtime.CompilerServices;
  14. // from tensorflow.python.distribute import distribution_strategy_context as ds_context;
  15. namespace Tensorflow.Keras.Layers
  16. {
  17. /// <summary>
  18. /// Base class for recurrent layers.
  19. /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
  20. /// for details about the usage of RNN API.
  21. /// </summary>
  22. public class RNN : RnnBase
  23. {
  24. private RNNArgs _args;
  25. private object _input_spec = null; // or NoneValue??
  26. private object _state_spec = null;
  27. private Tensors _states = null;
  28. private object _constants_spec = null;
  29. private int _num_constants;
  30. protected IVariableV1 _kernel;
  31. protected IVariableV1 _bias;
  32. private IRnnCell _cell;
  33. public RNNArgs Args { get => _args; }
  34. public IRnnCell Cell
  35. {
  36. get
  37. {
  38. return _cell;
  39. }
  40. init
  41. {
  42. _cell = value;
  43. _self_tracked_trackables.Add(_cell);
  44. }
  45. }
  46. public RNN(IRnnCell cell, RNNArgs args) : base(PreConstruct(args))
  47. {
  48. _args = args;
  49. SupportsMasking = true;
  50. Cell = cell;
  51. // get input_shape
  52. _args = PreConstruct(args);
  53. _num_constants = 0;
  54. }
  55. public RNN(IEnumerable<IRnnCell> cells, RNNArgs args) : base(PreConstruct(args))
  56. {
  57. _args = args;
  58. SupportsMasking = true;
  59. Cell = new StackedRNNCells(cells, new StackedRNNCellsArgs());
  60. // get input_shape
  61. _args = PreConstruct(args);
  62. _num_constants = 0;
  63. }
  64. // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...)
  65. // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape
  66. public Tensors States
  67. {
  68. get
  69. {
  70. if (_states == null)
  71. {
  72. // CHECK(Rinne): check if this is correct.
  73. var nested = Cell.StateSize.MapStructure<Tensor?>(x => null);
  74. _states = nested.AsNest().ToTensors();
  75. }
  76. return _states;
  77. }
  78. set { _states = value; }
  79. }
  80. private INestStructure<Shape> compute_output_shape(Shape input_shape)
  81. {
  82. var batch = input_shape[0];
  83. var time_step = input_shape[1];
  84. if (_args.TimeMajor)
  85. {
  86. (batch, time_step) = (time_step, batch);
  87. }
  88. // state_size is a array of ints or a positive integer
  89. var state_size = Cell.StateSize;
  90. if(state_size?.TotalNestedCount == 1)
  91. {
  92. state_size = new NestList<long>(state_size.Flatten().First());
  93. }
  94. Func<long, Shape> _get_output_shape = (flat_output_size) =>
  95. {
  96. var output_dim = new Shape(flat_output_size).as_int_list();
  97. Shape output_shape;
  98. if (_args.ReturnSequences)
  99. {
  100. if (_args.TimeMajor)
  101. {
  102. output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
  103. }
  104. else
  105. {
  106. output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
  107. }
  108. }
  109. else
  110. {
  111. output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
  112. }
  113. return output_shape;
  114. };
  115. Type type = Cell.GetType();
  116. PropertyInfo output_size_info = type.GetProperty("output_size");
  117. INestStructure<Shape> output_shape;
  118. if (output_size_info != null)
  119. {
  120. output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize);
  121. }
  122. else
  123. {
  124. output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First()));
  125. }
  126. if (_args.ReturnState)
  127. {
  128. Func<long, Shape> _get_state_shape = (flat_state) =>
  129. {
  130. var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list());
  131. return new Shape(state_shape);
  132. };
  133. var state_shape = Nest.MapStructure(_get_state_shape, state_size);
  134. return new Nest<Shape>(new[] { output_shape, state_shape } );
  135. }
  136. else
  137. {
  138. return output_shape;
  139. }
  140. }
  141. private Tensors compute_mask(Tensors inputs, Tensors mask)
  142. {
  143. // Time step masks must be the same for each input.
  144. // This is because the mask for an RNN is of size [batch, time_steps, 1],
  145. // and specifies which time steps should be skipped, and a time step
  146. // must be skipped for all inputs.
  147. mask = nest.flatten(mask)[0];
  148. var output_mask = _args.ReturnSequences ? mask : null;
  149. if (_args.ReturnState)
  150. {
  151. var state_mask = new List<Tensor>();
  152. for (int i = 0; i < len(States); i++)
  153. {
  154. state_mask.Add(null);
  155. }
  156. return new List<Tensor> { output_mask }.concat(state_mask);
  157. }
  158. else
  159. {
  160. return output_mask;
  161. }
  162. }
  163. public override void build(KerasShapesWrapper input_shape)
  164. {
  165. _buildInputShape = input_shape;
  166. input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);
  167. InputSpec get_input_spec(Shape shape)
  168. {
  169. var input_spec_shape = shape.as_int_list();
  170. var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
  171. if (!_args.Stateful)
  172. {
  173. input_spec_shape[batch_index] = -1;
  174. }
  175. input_spec_shape[time_step_index] = -1;
  176. return new InputSpec(shape: input_spec_shape);
  177. }
  178. Shape get_step_input_shape(Shape shape)
  179. {
  180. // return shape[1:] if self.time_major else (shape[0],) + shape[2:]
  181. if (_args.TimeMajor)
  182. {
  183. return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
  184. }
  185. else
  186. {
  187. return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
  188. }
  189. }
  190. object get_state_spec(Shape shape)
  191. {
  192. var state_spec_shape = shape.as_int_list();
  193. // append bacth dim
  194. state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
  195. return new InputSpec(shape: state_spec_shape);
  196. }
  197. // Check whether the input shape contains any nested shapes. It could be
  198. // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
  199. // numpy inputs.
  200. if (Cell is Layer layer && !layer.Built)
  201. {
  202. layer.build(input_shape);
  203. layer.Built = true;
  204. }
  205. this.built = true;
  206. }
  207. /// <summary>
  208. ///
  209. /// </summary>
  210. /// <param name="inputs"></param>
  211. /// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell</param>
  212. /// <param name="training"></param>
  213. /// <param name="optional_args"></param>
  214. /// <returns></returns>
  215. /// <exception cref="ValueError"></exception>
  216. /// <exception cref="NotImplementedException"></exception>
  217. protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
  218. {
  219. RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
  220. if(optional_args is not null && rnn_optional_args is null)
  221. {
  222. throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`");
  223. }
  224. Tensors? constants = rnn_optional_args?.Constants;
  225. Tensors? mask = rnn_optional_args?.Mask;
  226. //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
  227. // 暂时先不接受ragged tensor
  228. int row_length = 0; // TODO(Rinne): support this param.
  229. bool is_ragged_input = false;
  230. _validate_args_if_ragged(is_ragged_input, mask);
  231. (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);
  232. _maybe_reset_cell_dropout_mask(Cell);
  233. if (Cell is StackedRNNCells)
  234. {
  235. var stack_cell = Cell as StackedRNNCells;
  236. foreach (IRnnCell cell in stack_cell.Cells)
  237. {
  238. _maybe_reset_cell_dropout_mask(cell);
  239. }
  240. }
  241. if (mask != null)
  242. {
  243. // Time step masks must be the same for each input.
  244. mask = mask.Flatten().First();
  245. }
  246. Shape input_shape;
  247. if (!inputs.IsNested())
  248. {
  249. // In the case of nested input, use the first element for shape check
  250. // input_shape = nest.flatten(inputs)[0].shape;
  251. // TODO(Wanglongzhi2001)
  252. input_shape = inputs.Flatten().First().shape;
  253. }
  254. else
  255. {
  256. input_shape = inputs.shape;
  257. }
  258. var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
  259. if (_args.Unroll && timesteps == null)
  260. {
  261. throw new ValueError(
  262. "Cannot unroll a RNN if the " +
  263. "time dimension is undefined. \n" +
  264. "- If using a Sequential model, " +
  265. "specify the time dimension by passing " +
  266. "an `input_shape` or `batch_input_shape` " +
  267. "argument to your first layer. If your " +
  268. "first layer is an Embedding, you can " +
  269. "also use the `input_length` argument.\n" +
  270. "- If using the functional API, specify " +
  271. "the time dimension by passing a `shape` " +
  272. "or `batch_shape` argument to your Input layer."
  273. );
  274. }
  275. // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
  276. Func<Tensors, Tensors, (Tensors, Tensors)> step;
  277. bool is_tf_rnn_cell = false;
  278. if (constants is not null)
  279. {
  280. if (!Cell.SupportOptionalArgs)
  281. {
  282. throw new ValueError(
  283. $"RNN cell {Cell} does not support constants." +
  284. $"Received: constants={constants}");
  285. }
  286. step = (inputs, states) =>
  287. {
  288. constants = new Tensors(states.TakeLast(_num_constants).ToArray());
  289. states = new Tensors(states.SkipLast(_num_constants).ToArray());
  290. states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
  291. var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
  292. return (output, new_states);
  293. };
  294. }
  295. else
  296. {
  297. step = (inputs, states) =>
  298. {
  299. states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
  300. var (output, new_states) = Cell.Apply(inputs, states);
  301. return (output, new_states);
  302. };
  303. }
  304. var (last_output, outputs, states) = keras.backend.rnn(
  305. step,
  306. inputs,
  307. initial_state,
  308. constants: constants,
  309. go_backwards: _args.GoBackwards,
  310. mask: mask,
  311. unroll: _args.Unroll,
  312. input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
  313. time_major: _args.TimeMajor,
  314. zero_output_for_mask: _args.ZeroOutputForMask,
  315. return_all_outputs: _args.ReturnSequences);
  316. if (_args.Stateful)
  317. {
  318. throw new NotImplementedException("this argument havn't been developed.");
  319. }
  320. Tensors output = new Tensors();
  321. if (_args.ReturnSequences)
  322. {
  323. // TODO(Rinne): add go_backwards parameter and revise the `row_length` param
  324. output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
  325. }
  326. else
  327. {
  328. output = last_output;
  329. }
  330. if (_args.ReturnState)
  331. {
  332. foreach (var state in states)
  333. {
  334. output.Add(state);
  335. }
  336. return output;
  337. }
  338. else
  339. {
  340. //var tapeSet = tf.GetTapeSet();
  341. //foreach(var tape in tapeSet)
  342. //{
  343. // tape.Watch(output);
  344. //}
  345. return output;
  346. }
  347. }
  348. public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null)
  349. {
  350. RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
  351. if (optional_args is not null && rnn_optional_args is null)
  352. {
  353. throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
  354. }
  355. Tensors? constants = rnn_optional_args?.Constants;
  356. (inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);
  357. if(initial_states is null && constants is null)
  358. {
  359. return base.Apply(inputs);
  360. }
  361. // TODO(Rinne): implement it.
  362. throw new NotImplementedException();
  363. }
  364. protected (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
  365. {
  366. if (inputs.Length > 1)
  367. {
  368. if (_num_constants != 0)
  369. {
  370. initial_state = new Tensors(inputs.Skip(1).ToArray());
  371. }
  372. else
  373. {
  374. initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray());
  375. constants = new Tensors(inputs.TakeLast(_num_constants).ToArray());
  376. }
  377. if (len(initial_state) == 0)
  378. initial_state = null;
  379. inputs = inputs[0];
  380. }
  381. if (_args.Stateful)
  382. {
  383. if (initial_state != null)
  384. {
  385. var tmp = new Tensor[] { };
  386. foreach (var s in nest.flatten(States))
  387. {
  388. tmp.add(tf.math.count_nonzero(s.Single()));
  389. }
  390. var non_zero_count = tf.add_n(tmp);
  391. initial_state = tf.cond(non_zero_count > 0, States, initial_state);
  392. if ((int)non_zero_count.numpy() > 0)
  393. {
  394. initial_state = States;
  395. }
  396. }
  397. else
  398. {
  399. initial_state = States;
  400. }
  401. //initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state);
  402. }
  403. else if (initial_state is null)
  404. {
  405. initial_state = get_initial_state(inputs);
  406. }
  407. if (initial_state.Length != States.Length)
  408. {
  409. throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
  410. $"but it received {initial_state.Length} " +
  411. $"initial state(s). Input received: {inputs}");
  412. }
  413. return (inputs, initial_state, constants);
  414. }
  415. private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
  416. {
  417. if (!is_ragged_input)
  418. {
  419. return;
  420. }
  421. if (_args.Unroll)
  422. {
  423. throw new ValueError("The input received contains RaggedTensors and does " +
  424. "not support unrolling. Disable unrolling by passing " +
  425. "`unroll=False` in the RNN Layer constructor.");
  426. }
  427. if (mask != null)
  428. {
  429. throw new ValueError($"The mask that was passed in was {mask}, which " +
  430. "cannot be applied to RaggedTensor inputs. Please " +
  431. "make sure that there is no mask injected by upstream " +
  432. "layers.");
  433. }
  434. }
  435. protected void _maybe_reset_cell_dropout_mask(ILayer cell)
  436. {
  437. if (cell is DropoutRNNCellMixin CellDRCMixin)
  438. {
  439. CellDRCMixin.reset_dropout_mask();
  440. CellDRCMixin.reset_recurrent_dropout_mask();
  441. }
  442. }
  443. private static RNNArgs PreConstruct(RNNArgs args)
  444. {
  445. // If true, the output for masked timestep will be zeros, whereas in the
  446. // false case, output from previous timestep is returned for masked timestep.
  447. var zeroOutputForMask = args.ZeroOutputForMask;
  448. Shape input_shape;
  449. var propIS = args.InputShape;
  450. var propID = args.InputDim;
  451. var propIL = args.InputLength;
  452. if (propIS == null && (propID != null || propIL != null))
  453. {
  454. input_shape = new Shape(
  455. propIL ?? -1,
  456. propID ?? -1);
  457. args.InputShape = input_shape;
  458. }
  459. return args;
  460. }
  461. public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
  462. {
  463. throw new NotImplementedException();
  464. }
  465. // 好像不能cell不能传接口类型
  466. //public RNN New(IRnnArgCell cell,
  467. // bool return_sequences = false,
  468. // bool return_state = false,
  469. // bool go_backwards = false,
  470. // bool stateful = false,
  471. // bool unroll = false,
  472. // bool time_major = false)
  473. // => new RNN(new RNNArgs
  474. // {
  475. // Cell = cell,
  476. // ReturnSequences = return_sequences,
  477. // ReturnState = return_state,
  478. // GoBackwards = go_backwards,
  479. // Stateful = stateful,
  480. // Unroll = unroll,
  481. // TimeMajor = time_major
  482. // });
  483. //public RNN New(List<IRnnArgCell> cell,
  484. // bool return_sequences = false,
  485. // bool return_state = false,
  486. // bool go_backwards = false,
  487. // bool stateful = false,
  488. // bool unroll = false,
  489. // bool time_major = false)
  490. // => new RNN(new RNNArgs
  491. // {
  492. // Cell = cell,
  493. // ReturnSequences = return_sequences,
  494. // ReturnState = return_state,
  495. // GoBackwards = go_backwards,
  496. // Stateful = stateful,
  497. // Unroll = unroll,
  498. // TimeMajor = time_major
  499. // });
  500. protected Tensors get_initial_state(Tensors inputs)
  501. {
  502. var input = inputs[0];
  503. var input_shape = array_ops.shape(inputs);
  504. var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
  505. var dtype = input.dtype;
  506. Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);
  507. return init_state;
  508. }
  509. public override IKerasConfig get_config()
  510. {
  511. return _args;
  512. }
  513. }
  514. }