You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

RNN.cs 22 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Reflection;
  5. using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs;
  6. using Tensorflow.Keras.ArgsDefinition.Rnn;
  7. using Tensorflow.Keras.Engine;
  8. using Tensorflow.Keras.Saving;
  9. using Tensorflow.Util;
  10. using OneOf;
  11. using OneOf.Types;
  12. using Tensorflow.Common.Extensions;
  13. // from tensorflow.python.distribute import distribution_strategy_context as ds_context;
  14. namespace Tensorflow.Keras.Layers.Rnn
  15. {
  16. public class RNN : Layer
  17. {
// Constructor/configuration arguments for this layer (cell, return_sequences, stateful, ...).
private RNNArgs args;
// Input spec; not populated yet in this port. // or NoneValue??
private object input_spec = null;
// State spec; not populated yet in this port.
private object state_spec = null;
// Backing field for the States property; null until states are recorded.
private Tensors _states = null;
// Spec for constants passed to the cell; not populated yet in this port.
private object constants_spec = null;
// Number of trailing constants packed into the flat inputs list (see _process_inputs).
private int _num_constants = 0;
protected IVariableV1 kernel;
protected IVariableV1 bias;
// The wrapped RNN cell; a list of cells is wrapped into a StackedRNNCells in the constructor.
protected ILayer cell;
  27. public RNN(RNNArgs args) : base(PreConstruct(args))
  28. {
  29. this.args = args;
  30. SupportsMasking = true;
  31. // if is StackedRnncell
  32. if (args.Cell.IsT0)
  33. {
  34. cell = new StackedRNNCells(new StackedRNNCellsArgs
  35. {
  36. Cells = args.Cell.AsT0,
  37. });
  38. }
  39. else
  40. {
  41. cell = args.Cell.AsT1;
  42. }
  43. Type type = cell.GetType();
  44. MethodInfo callMethodInfo = type.GetMethod("Call");
  45. if (callMethodInfo == null)
  46. {
  47. throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. ");
  48. }
  49. PropertyInfo state_size_info = type.GetProperty("state_size");
  50. if (state_size_info == null)
  51. {
  52. throw new ValueError(@"The RNN cell should have a `state_size` attribute");
  53. }
  54. // get input_shape
  55. this.args = PreConstruct(args);
  56. // The input shape is unknown yet, it could have nested tensor inputs, and
  57. // the input spec will be the list of specs for nested inputs, the structure
  58. // of the input_spec will be the same as the input.
  59. //if(stateful)
  60. //{
  61. // if (ds_context.has_strategy()) // ds_context????
  62. // {
  63. // throw new Exception("RNNs with stateful=True not yet supported with tf.distribute.Strategy");
  64. // }
  65. //}
  66. }
// States is a tuple consisting of the cell state sizes, e.g. (cell1.state_size, cell2.state_size, ...).
// state_size can be a single integer, a list/tuple of integers, a TensorShape, or a list/tuple of TensorShape.
public Tensors States
{
    get
    {
        if (_states == null)
        {
            // No state recorded yet: return a null-filled structure matching
            // the layout of cell.state_size so callers see the right arity.
            var state = nest.map_structure(x => null, cell.state_size);
            return nest.is_nested(state) ? state : new Tensors { state };
        }
        return _states;
    }
    set { _states = value; }
}
// Computes the layer's output shape(s) from the (possibly time-major) input shape.
// Returns a single Shape, or [output_shape, state_shape] when ReturnState is set.
private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
{
    var batch = input_shape[0];
    var time_step = input_shape[1];
    // In time-major layout the leading dims are (time, batch) instead of (batch, time).
    if (args.TimeMajor)
    {
        (batch, time_step) = (time_step, batch);
    }
    // state_size is an array of ints or a positive integer
    var state_size = cell.state_size;
    // TODO(wanglongzhi2001): what type should flat_output_size be — Shape or Tensor?
    Func<Shape, Shape> _get_output_shape;
    // Builds the output shape for one flat output size, honoring
    // ReturnSequences (keep the time axis) and TimeMajor (axis order).
    _get_output_shape = (flat_output_size) =>
    {
        var output_dim = flat_output_size.as_int_list();
        Shape output_shape;
        if (args.ReturnSequences)
        {
            if (args.TimeMajor)
            {
                output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
            }
            else
            {
                output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
            }
        }
        else
        {
            output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
        }
        return output_shape;
    };
    // Prefer the cell's own output_size when it exposes one (checked via
    // reflection); otherwise fall back to the first entry of state_size
    // (Keras convention).
    Type type = cell.GetType();
    PropertyInfo output_size_info = type.GetProperty("output_size");
    Shape output_shape;
    if (output_size_info != null)
    {
        output_shape = nest.map_structure(_get_output_shape, cell.output_size);
        // TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape type?
        output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
    }
    else
    {
        output_shape = _get_output_shape(state_size[0]);
    }
    if (args.ReturnState)
    {
        Func<Shape, Shape> _get_state_shape;
        // State shapes get the batch dimension prepended.
        _get_state_shape = (flat_state) =>
        {
            var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
            return new Shape(state_shape);
        };
        var state_shape = _get_state_shape(new Shape(state_size.ToArray()));
        return new List<Shape> { output_shape, state_shape };
    }
    else
    {
        return output_shape;
    }
}
  144. private Tensors compute_mask(Tensors inputs, Tensors mask)
  145. {
  146. // Time step masks must be the same for each input.
  147. // This is because the mask for an RNN is of size [batch, time_steps, 1],
  148. // and specifies which time steps should be skipped, and a time step
  149. // must be skipped for all inputs.
  150. mask = nest.flatten(mask)[0];
  151. var output_mask = args.ReturnSequences ? mask : null;
  152. if (args.ReturnState)
  153. {
  154. var state_mask = new List<Tensor>();
  155. for (int i = 0; i < len(States); i++)
  156. {
  157. state_mask.Add(null);
  158. }
  159. return new List<Tensor> { output_mask }.concat(state_mask);
  160. }
  161. else
  162. {
  163. return output_mask;
  164. }
  165. }
// Builds the layer: currently only forwards the build to the wrapped cell.
// NOTE(review): the three local helpers below are defined but never invoked —
// the input/state spec wiring from Keras has not been ported yet. They are
// kept for the follow-up implementation.
public override void build(KerasShapesWrapper input_shape)
{
    // Input spec with batch (unless stateful) and time dims left free (-1).
    object get_input_spec(Shape shape)
    {
        var input_spec_shape = shape.as_int_list();
        var (batch_index, time_step_index) = args.TimeMajor ? (1, 0) : (0, 1);
        if (!args.Stateful)
        {
            input_spec_shape[batch_index] = -1;
        }
        input_spec_shape[time_step_index] = -1;
        return new InputSpec(shape: input_spec_shape);
    }
    // Shape of a single time step's input, dropping the time axis.
    Shape get_step_input_shape(Shape shape)
    {
        // return shape[1:] if self.time_major else (shape[0],) + shape[2:]
        if (args.TimeMajor)
        {
            return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
        }
        else
        {
            return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
        }
    }
    // State spec with a free batch dimension prepended.
    object get_state_spec(Shape shape)
    {
        var state_spec_shape = shape.as_int_list();
        // append batch dim
        state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
        return new InputSpec(shape: state_spec_shape);
    }
    // Check whether the input shape contains any nested shapes. It could be
    // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
    // numpy inputs.
    if (!cell.Built)
    {
        cell.build(input_shape);
    }
}
  206. // inputs: Tensors
  207. // mask: Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked
  208. // training: bool
  209. // initial_state: List of initial state tensors to be passed to the first call of the cell
  210. // constants: List of constant tensors to be passed to the cell at each timestep
  211. protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
  212. {
  213. //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
  214. // 暂时先不接受ragged tensor
  215. int? row_length = null;
  216. bool is_ragged_input = false;
  217. _validate_args_if_ragged(is_ragged_input, mask);
  218. (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);
  219. _maybe_reset_cell_dropout_mask(cell);
  220. if (cell is StackedRNNCells)
  221. {
  222. var stack_cell = cell as StackedRNNCells;
  223. foreach (var cell in stack_cell.Cells)
  224. {
  225. _maybe_reset_cell_dropout_mask(cell);
  226. }
  227. }
  228. if (mask != null)
  229. {
  230. // Time step masks must be the same for each input.
  231. mask = nest.flatten(mask)[0];
  232. }
  233. Shape input_shape;
  234. if (nest.is_nested(inputs))
  235. {
  236. // In the case of nested input, use the first element for shape check
  237. // input_shape = nest.flatten(inputs)[0].shape;
  238. // TODO(Wanglongzhi2001)
  239. input_shape = nest.flatten(inputs)[0].shape;
  240. }
  241. else
  242. {
  243. input_shape = inputs.shape;
  244. }
  245. var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];
  246. if (args.Unroll && timesteps != null)
  247. {
  248. throw new ValueError(
  249. "Cannot unroll a RNN if the " +
  250. "time dimension is undefined. \n" +
  251. "- If using a Sequential model, " +
  252. "specify the time dimension by passing " +
  253. "an `input_shape` or `batch_input_shape` " +
  254. "argument to your first layer. If your " +
  255. "first layer is an Embedding, you can " +
  256. "also use the `input_length` argument.\n" +
  257. "- If using the functional API, specify " +
  258. "the time dimension by passing a `shape` " +
  259. "or `batch_shape` argument to your Input layer."
  260. );
  261. }
  262. // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
  263. var cell_call_fn = cell.Call;
  264. Func<Tensors, Tensors, (Tensors, Tensors)> step;
  265. if (constants != null)
  266. {
  267. ParameterInfo[] parameters = cell_call_fn.GetMethodInfo().GetParameters();
  268. bool hasParam = parameters.Any(p => p.Name == "constants");
  269. if (!hasParam)
  270. {
  271. throw new ValueError(
  272. $"RNN cell {cell} does not support constants." +
  273. $"Received: constants={constants}");
  274. }
  275. step = (inputs, states) =>
  276. {
  277. // constants = states[-self._num_constants :]
  278. constants = states.numpy()[new Slice(states.Length - _num_constants, states.Length)];
  279. // states = states[: -self._num_constants]
  280. states = states.numpy()[new Slice(0, states.Length - _num_constants)];
  281. // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
  282. states = states.Length == 1 ? states[0] : states;
  283. var (output, new_states) = cell_call_fn(inputs, null, null, states, constants);
  284. // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
  285. if (!nest.is_nested(new_states))
  286. {
  287. return (output, new Tensors { new_states });
  288. }
  289. return (output, new_states);
  290. };
  291. }
  292. else
  293. {
  294. step = (inputs, states) =>
  295. {
  296. // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
  297. states = states.Length == 1 ? states[0] : states;
  298. var (output, new_states) = cell_call_fn(inputs, null, null, states, constants);
  299. if (!nest.is_nested(new_states))
  300. {
  301. return (output, new Tensors { new_states });
  302. }
  303. return (output, new_states);
  304. };
  305. }
  306. var (last_output, outputs, states) = BackendImpl.rnn(step,
  307. inputs,
  308. initial_state,
  309. constants: constants,
  310. go_backwards: args.GoBackwards,
  311. mask: mask,
  312. unroll: args.Unroll,
  313. input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
  314. time_major: args.TimeMajor,
  315. zero_output_for_mask: args.ZeroOutputForMask,
  316. return_all_outputs: args.ReturnSequences);
  317. if (args.Stateful)
  318. {
  319. throw new NotImplementedException("this argument havn't been developed!");
  320. }
  321. Tensors output = new Tensors();
  322. if (args.ReturnSequences)
  323. {
  324. throw new NotImplementedException("this argument havn't been developed!");
  325. }
  326. else
  327. {
  328. output = last_output;
  329. }
  330. if (args.ReturnState)
  331. {
  332. foreach (var state in states)
  333. {
  334. output.Add(state);
  335. }
  336. return output;
  337. }
  338. else
  339. {
  340. return output;
  341. }
  342. }
  343. private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants)
  344. {
  345. if (nest.is_sequence(input))
  346. {
  347. if (_num_constants != 0)
  348. {
  349. initial_state = inputs[new Slice(1, len(inputs))];
  350. }
  351. else
  352. {
  353. initial_state = inputs[new Slice(1, len(inputs) - _num_constants)];
  354. constants = inputs[new Slice(len(inputs) - _num_constants, len(inputs))];
  355. }
  356. if (len(initial_state) == 0)
  357. initial_state = null;
  358. inputs = inputs[0];
  359. }
  360. if (args.Stateful)
  361. {
  362. if (initial_state != null)
  363. {
  364. var tmp = new Tensor[] { };
  365. foreach (var s in nest.flatten(States))
  366. {
  367. tmp.add(tf.math.count_nonzero((Tensor)s));
  368. }
  369. var non_zero_count = tf.add_n(tmp);
  370. //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
  371. if((int)non_zero_count.numpy() > 0)
  372. {
  373. initial_state = States;
  374. }
  375. }
  376. else
  377. {
  378. initial_state = States;
  379. }
  380. }
  381. else if(initial_state != null)
  382. {
  383. initial_state = get_initial_state(inputs);
  384. }
  385. if (initial_state.Length != States.Length)
  386. {
  387. throw new ValueError(
  388. $"Layer {this} expects {States.Length} state(s), " +
  389. $"but it received {initial_state.Length} " +
  390. $"initial state(s). Input received: {inputs}");
  391. }
  392. return (inputs, initial_state, constants);
  393. }
  394. private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
  395. {
  396. if (!is_ragged_input)
  397. {
  398. return;
  399. }
  400. if (args.Unroll)
  401. {
  402. throw new ValueError("The input received contains RaggedTensors and does " +
  403. "not support unrolling. Disable unrolling by passing " +
  404. "`unroll=False` in the RNN Layer constructor.");
  405. }
  406. if (mask != null)
  407. {
  408. throw new ValueError($"The mask that was passed in was {mask}, which " +
  409. "cannot be applied to RaggedTensor inputs. Please " +
  410. "make sure that there is no mask injected by upstream " +
  411. "layers.");
  412. }
  413. }
// Intentionally a no-op: DropoutRNNCellMixin has not been ported yet, so
// there are no cached dropout masks to clear between calls.
void _maybe_reset_cell_dropout_mask(ILayer cell)
{
    //if (cell is DropoutRNNCellMixin)
    //{
    //    cell.reset_dropout_mask();
    //    cell.reset_recurrent_dropout_mask();
    //}
}
  422. private static RNNArgs PreConstruct(RNNArgs args)
  423. {
  424. if (args.Kwargs == null)
  425. {
  426. args.Kwargs = new Dictionary<string, object>();
  427. }
  428. // If true, the output for masked timestep will be zeros, whereas in the
  429. // false case, output from previous timestep is returned for masked timestep.
  430. var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
  431. Shape input_shape;
  432. var propIS = (Shape)args.Kwargs.Get("input_shape", null);
  433. var propID = (int?)args.Kwargs.Get("input_dim", null);
  434. var propIL = (int?)args.Kwargs.Get("input_length", null);
  435. if (propIS == null && (propID != null || propIL != null))
  436. {
  437. input_shape = new Shape(
  438. propIL ?? -1,
  439. propID ?? -1);
  440. args.Kwargs["input_shape"] = input_shape;
  441. }
  442. return args;
  443. }
// Python-style entry point; not implemented in this port — use the layer's
// standard Apply/Call path instead.
public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
{
    throw new NotImplementedException();
}
  448. // 好像不能cell不能传接口类型
  449. //public RNN New(IRnnArgCell cell,
  450. // bool return_sequences = false,
  451. // bool return_state = false,
  452. // bool go_backwards = false,
  453. // bool stateful = false,
  454. // bool unroll = false,
  455. // bool time_major = false)
  456. // => new RNN(new RNNArgs
  457. // {
  458. // Cell = cell,
  459. // ReturnSequences = return_sequences,
  460. // ReturnState = return_state,
  461. // GoBackwards = go_backwards,
  462. // Stateful = stateful,
  463. // Unroll = unroll,
  464. // TimeMajor = time_major
  465. // });
  466. //public RNN New(List<IRnnArgCell> cell,
  467. // bool return_sequences = false,
  468. // bool return_state = false,
  469. // bool go_backwards = false,
  470. // bool stateful = false,
  471. // bool unroll = false,
  472. // bool time_major = false)
  473. // => new RNN(new RNNArgs
  474. // {
  475. // Cell = cell,
  476. // ReturnSequences = return_sequences,
  477. // ReturnState = return_state,
  478. // GoBackwards = go_backwards,
  479. // Stateful = stateful,
  480. // Unroll = unroll,
  481. // TimeMajor = time_major
  482. // });
// Builds the initial state for the first time step: prefers the cell's own
// get_initial_state (looked up via reflection), otherwise falls back to a
// zero-filled state sized from the runtime batch dimension and input dtype.
protected Tensors get_initial_state(Tensor inputs)
{
    Type type = cell.GetType();
    MethodInfo MethodInfo = type.GetMethod("get_initial_state");
    if (nest.is_nested(inputs))
    {
        // The input are nested sequences. Use the first element in the seq
        // to get batch size and dtype.
        inputs = nest.flatten(inputs)[0];
    }
    var input_shape = tf.shape(inputs);
    // In time-major layout the batch dimension is axis 1, otherwise axis 0.
    var batch_size = args.TimeMajor ? input_shape[1] : input_shape[0];
    var dtype = inputs.dtype;
    Tensor init_state;
    if (MethodInfo != null)
    {
        // Mirrors cell.get_initial_state(inputs=None, batch_size, dtype).
        init_state = (Tensor)MethodInfo.Invoke(cell, new object[] { null, batch_size, dtype });
    }
    else
    {
        init_state = RNNUtils.generate_zero_filled_state(batch_size, cell.state_size, dtype);
    }
    //if (!nest.is_nested(init_state))
    //{
    //    init_state = new List<Tensor> { init_state};
    //}
    return new List<Tensor> { init_state };
    //return _generate_zero_filled_state_for_cell(null, null);
}
// Placeholder for the Keras zero-filled-state helper; not implemented yet
// (get_initial_state uses RNNUtils.generate_zero_filled_state instead).
Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size)
{
    throw new NotImplementedException("");
}
  516. // Check whether the state_size contains multiple states.
  517. public static bool is_multiple_state(object state_size)
  518. {
  519. var myIndexerProperty = state_size.GetType().GetProperty("Item");
  520. return myIndexerProperty != null
  521. && myIndexerProperty.GetIndexParameters().Length == 1
  522. && !(state_size.GetType() == typeof(Shape));
  523. }
  524. }
  525. }