You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

RNN.cs 39 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using OneOf;
using OneOf.Types;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.Util;
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs;
  28. // from tensorflow.python.distribute import distribution_strategy_context as ds_context;
  29. namespace Tensorflow.Keras.Layers.Rnn
  30. {
  31. /// <summary>
  32. /// Base class for recurrent layers.
  33. /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
  34. /// for details about the usage of RNN API.
  35. /// </summary>
  36. public class RNN : RnnBase
  37. {
  38. <<<<<<< HEAD
  39. private RNNArgs _args;
  40. private object _input_spec = null; // or NoneValue??
  41. private object _state_spec = null;
  42. private Tensors _states = null;
  43. private object _constants_spec = null;
  44. private int _num_constants;
  45. protected IVariableV1 _kernel;
  46. protected IVariableV1 _bias;
  47. protected IRnnCell _cell;
  48. =======
  49. private RNNArgs args;
  50. private object input_spec = null; // or NoneValue??
  51. private object state_spec = null;
  52. private Tensors _states = null;
  53. private object constants_spec = null;
  54. private int _num_constants = 0;
  55. protected IVariableV1 kernel;
  56. protected IVariableV1 bias;
  57. protected ILayer cell;
  58. >>>>>>> master
  59. public RNN(RNNArgs args) : base(PreConstruct(args))
  60. {
  61. _args = args;
  62. SupportsMasking = true;
  63. // if is StackedRnncell
  64. <<<<<<< HEAD
  65. if (args.Cells != null)
  66. {
  67. _cell = new StackedRNNCells(new StackedRNNCellsArgs
  68. {
  69. Cells = args.Cells
  70. =======
  71. if (args.Cell.IsT0)
  72. {
  73. cell = new StackedRNNCells(new StackedRNNCellsArgs
  74. {
  75. Cells = args.Cell.AsT0,
  76. >>>>>>> master
  77. });
  78. }
  79. else
  80. {
  81. <<<<<<< HEAD
  82. _cell = args.Cell;
  83. }
  84. =======
  85. cell = args.Cell.AsT1;
  86. }
  87. Type type = cell.GetType();
  88. MethodInfo callMethodInfo = type.GetMethod("Call");
  89. if (callMethodInfo == null)
  90. {
  91. throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. ");
  92. }
  93. PropertyInfo state_size_info = type.GetProperty("state_size");
  94. if (state_size_info == null)
  95. {
  96. throw new ValueError(@"The RNN cell should have a `state_size` attribute");
  97. }
  98. // get input_shape
  99. this.args = PreConstruct(args);
  100. // The input shape is unknown yet, it could have nested tensor inputs, and
  101. // the input spec will be the list of specs for nested inputs, the structure
  102. // of the input_spec will be the same as the input.
  103. >>>>>>> master
  104. // get input_shape
  105. _args = PreConstruct(args);
  106. _num_constants = 0;
  107. }
  108. // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...)
  109. // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape
  110. public Tensors States
  111. {
  112. get
  113. {
  114. if (_states == null)
  115. {
  116. // CHECK(Rinne): check if this is correct.
  117. var nested = _cell.StateSize.MapStructure<Tensor?>(x => null);
  118. _states = nested.AsNest().ToTensors();
  119. }
  120. return _states;
  121. }
  122. set { _states = value; }
  123. }
  124. private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
  125. {
  126. var batch = input_shape[0];
  127. var time_step = input_shape[1];
  128. if (_args.TimeMajor)
  129. {
  130. (batch, time_step) = (time_step, batch);
  131. }
  132. // state_size is a array of ints or a positive integer
  133. var state_size = _cell.StateSize.ToSingleShape();
  134. // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
  135. Func<Shape, Shape> _get_output_shape;
  136. _get_output_shape = (flat_output_size) =>
  137. {
  138. var output_dim = flat_output_size.as_int_list();
  139. Shape output_shape;
  140. if (_args.ReturnSequences)
  141. {
  142. if (_args.TimeMajor)
  143. {
  144. output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
  145. }
  146. else
  147. {
  148. output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
  149. }
  150. }
  151. else
  152. {
  153. output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
  154. }
  155. return output_shape;
  156. };
  157. Type type = _cell.GetType();
  158. PropertyInfo output_size_info = type.GetProperty("output_size");
  159. Shape output_shape;
  160. if (output_size_info != null)
  161. {
  162. output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
  163. // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
  164. output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
  165. }
  166. else
  167. {
  168. output_shape = _get_output_shape(state_size);
  169. }
  170. if (_args.ReturnState)
  171. {
  172. Func<Shape, Shape> _get_state_shape;
  173. _get_state_shape = (flat_state) =>
  174. {
  175. var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
  176. return new Shape(state_shape);
  177. };
  178. var state_shape = _get_state_shape(state_size);
  179. return new List<Shape> { output_shape, state_shape };
  180. }
  181. else
  182. {
  183. return output_shape;
  184. }
  185. }
  186. private Tensors compute_mask(Tensors inputs, Tensors mask)
  187. {
  188. // Time step masks must be the same for each input.
  189. // This is because the mask for an RNN is of size [batch, time_steps, 1],
  190. // and specifies which time steps should be skipped, and a time step
  191. // must be skipped for all inputs.
  192. mask = nest.flatten(mask)[0];
  193. var output_mask = _args.ReturnSequences ? mask : null;
  194. if (_args.ReturnState)
  195. {
  196. var state_mask = new List<Tensor>();
  197. for (int i = 0; i < len(States); i++)
  198. {
  199. state_mask.Add(null);
  200. }
  201. return new List<Tensor> { output_mask }.concat(state_mask);
  202. }
  203. else
  204. {
  205. return output_mask;
  206. }
  207. }
  208. // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...)
  209. // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape
  210. public Tensors States
  211. {
  212. get
  213. {
  214. if (_states == null)
  215. {
  216. var state = nest.map_structure(x => null, cell.state_size);
  217. return nest.is_nested(state) ? state : new Tensors { state };
  218. }
  219. return _states;
  220. }
  221. set { _states = value; }
  222. }
  223. private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
  224. {
  225. var batch = input_shape[0];
  226. var time_step = input_shape[1];
  227. if (args.TimeMajor)
  228. {
  229. (batch, time_step) = (time_step, batch);
  230. }
  231. // state_size is a array of ints or a positive integer
  232. var state_size = cell.state_size;
  233. // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
  234. Func<Shape, Shape> _get_output_shape;
  235. _get_output_shape = (flat_output_size) =>
  236. {
  237. var output_dim = flat_output_size.as_int_list();
  238. Shape output_shape;
  239. if (args.ReturnSequences)
  240. {
  241. if (args.TimeMajor)
  242. {
  243. output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
  244. }
  245. else
  246. {
  247. output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
  248. }
  249. }
  250. else
  251. {
  252. output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
  253. }
  254. return output_shape;
  255. };
  256. Type type = cell.GetType();
  257. PropertyInfo output_size_info = type.GetProperty("output_size");
  258. Shape output_shape;
  259. if (output_size_info != null)
  260. {
  261. output_shape = nest.map_structure(_get_output_shape, cell.output_size);
  262. // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
  263. output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
  264. }
  265. else
  266. {
  267. output_shape = _get_output_shape(state_size[0]);
  268. }
  269. if (args.ReturnState)
  270. {
  271. Func<Shape, Shape> _get_state_shape;
  272. _get_state_shape = (flat_state) =>
  273. {
  274. var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
  275. return new Shape(state_shape);
  276. };
  277. var state_shape = _get_state_shape(new Shape(state_size.ToArray()));
  278. return new List<Shape> { output_shape, state_shape };
  279. }
  280. else
  281. {
  282. return output_shape;
  283. }
  284. }
  285. private Tensors compute_mask(Tensors inputs, Tensors mask)
  286. {
  287. // Time step masks must be the same for each input.
  288. // This is because the mask for an RNN is of size [batch, time_steps, 1],
  289. // and specifies which time steps should be skipped, and a time step
  290. // must be skipped for all inputs.
  291. mask = nest.flatten(mask)[0];
  292. var output_mask = args.ReturnSequences ? mask : null;
  293. if (args.ReturnState)
  294. {
  295. var state_mask = new List<Tensor>();
  296. for (int i = 0; i < len(States); i++)
  297. {
  298. state_mask.Add(null);
  299. }
  300. return new List<Tensor> { output_mask }.concat(state_mask);
  301. }
  302. else
  303. {
  304. return output_mask;
  305. }
  306. }
  307. public override void build(KerasShapesWrapper input_shape)
  308. {
  309. object get_input_spec(Shape shape)
  310. <<<<<<< HEAD
  311. =======
  312. {
  313. var input_spec_shape = shape.as_int_list();
  314. var (batch_index, time_step_index) = args.TimeMajor ? (1, 0) : (0, 1);
  315. if (!args.Stateful)
  316. {
  317. input_spec_shape[batch_index] = -1;
  318. }
  319. input_spec_shape[time_step_index] = -1;
  320. return new InputSpec(shape: input_spec_shape);
  321. }
  322. Shape get_step_input_shape(Shape shape)
  323. {
  324. // return shape[1:] if self.time_major else (shape[0],) + shape[2:]
  325. if (args.TimeMajor)
  326. {
  327. return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
  328. }
  329. else
  330. {
  331. return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
  332. }
  333. }
  334. object get_state_spec(Shape shape)
  335. {
  336. var state_spec_shape = shape.as_int_list();
  337. // append bacth dim
  338. state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
  339. return new InputSpec(shape: state_spec_shape);
  340. }
  341. // Check whether the input shape contains any nested shapes. It could be
  342. // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
  343. // numpy inputs.
  344. if (!cell.Built)
  345. >>>>>>> master
  346. {
  347. var input_spec_shape = shape.as_int_list();
  348. var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
  349. if (!_args.Stateful)
  350. {
  351. input_spec_shape[batch_index] = -1;
  352. }
  353. input_spec_shape[time_step_index] = -1;
  354. return new InputSpec(shape: input_spec_shape);
  355. }
  356. Shape get_step_input_shape(Shape shape)
  357. {
  358. // return shape[1:] if self.time_major else (shape[0],) + shape[2:]
  359. if (_args.TimeMajor)
  360. {
  361. return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
  362. }
  363. else
  364. {
  365. return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
  366. }
  367. }
  368. object get_state_spec(Shape shape)
  369. {
  370. var state_spec_shape = shape.as_int_list();
  371. // append bacth dim
  372. state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
  373. return new InputSpec(shape: state_spec_shape);
  374. }
  375. // Check whether the input shape contains any nested shapes. It could be
  376. // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
  377. // numpy inputs.
  378. if (!_cell.Built)
  379. {
  380. _cell.build(input_shape);
  381. }
  382. }
  383. <<<<<<< HEAD
  384. /// <summary>
  385. ///
  386. /// </summary>
  387. /// <param name="inputs"></param>
  388. /// <param name="mask">Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked</param>
  389. /// <param name="training"></param>
  390. /// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell</param>
  391. /// <param name="constants">List of constant tensors to be passed to the cell at each timestep</param>
  392. /// <returns></returns>
  393. /// <exception cref="ValueError"></exception>
  394. /// <exception cref="NotImplementedException"></exception>
  395. protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
  396. {
  397. RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
  398. if(optional_args is not null && rnn_optional_args is null)
  399. {
  400. throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`");
  401. }
  402. Tensors? constants = rnn_optional_args?.Constants;
  403. Tensors? mask = rnn_optional_args?.Mask;
  404. //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
  405. // 暂时先不接受ragged tensor
  406. int row_length = 0; // TODO(Rinne): support this param.
  407. =======
  408. // inputs: Tensors
  409. // mask: Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked
  410. // training: bool
  411. // initial_state: List of initial state tensors to be passed to the first call of the cell
  412. // constants: List of constant tensors to be passed to the cell at each timestep
  413. protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
  414. {
  415. //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
  416. // 暂时先不接受ragged tensor
  417. int? row_length = null;
  418. >>>>>>> master
  419. bool is_ragged_input = false;
  420. _validate_args_if_ragged(is_ragged_input, mask);
  421. (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);
  422. <<<<<<< HEAD
  423. _maybe_reset_cell_dropout_mask(_cell);
  424. if (_cell is StackedRNNCells)
  425. {
  426. var stack_cell = _cell as StackedRNNCells;
  427. foreach (IRnnCell cell in stack_cell.Cells)
  428. =======
  429. _maybe_reset_cell_dropout_mask(cell);
  430. if (cell is StackedRNNCells)
  431. {
  432. var stack_cell = cell as StackedRNNCells;
  433. foreach (var cell in stack_cell.Cells)
  434. >>>>>>> master
  435. {
  436. _maybe_reset_cell_dropout_mask(cell);
  437. }
  438. }
  439. if (mask != null)
  440. {
  441. // Time step masks must be the same for each input.
  442. <<<<<<< HEAD
  443. mask = mask.Flatten().First();
  444. }
  445. Shape input_shape;
  446. if (!inputs.IsNested())
  447. =======
  448. mask = nest.flatten(mask)[0];
  449. }
  450. Shape input_shape;
  451. if (nest.is_nested(inputs))
  452. >>>>>>> master
  453. {
  454. // In the case of nested input, use the first element for shape check
  455. // input_shape = nest.flatten(inputs)[0].shape;
  456. // TODO(Wanglongzhi2001)
  457. <<<<<<< HEAD
  458. input_shape = inputs.Flatten().First().shape;
  459. =======
  460. input_shape = nest.flatten(inputs)[0].shape;
  461. >>>>>>> master
  462. }
  463. else
  464. {
  465. input_shape = inputs.shape;
  466. }
  467. <<<<<<< HEAD
  468. var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
  469. if (_args.Unroll && timesteps == null)
  470. =======
  471. var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];
  472. if (args.Unroll && timesteps != null)
  473. >>>>>>> master
  474. {
  475. throw new ValueError(
  476. "Cannot unroll a RNN if the " +
  477. "time dimension is undefined. \n" +
  478. "- If using a Sequential model, " +
  479. "specify the time dimension by passing " +
  480. "an `input_shape` or `batch_input_shape` " +
  481. "argument to your first layer. If your " +
  482. "first layer is an Embedding, you can " +
  483. "also use the `input_length` argument.\n" +
  484. "- If using the functional API, specify " +
  485. "the time dimension by passing a `shape` " +
  486. "or `batch_shape` argument to your Input layer."
  487. );
  488. }
  489. // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
  490. <<<<<<< HEAD
  491. Func<Tensors, Tensors, (Tensors, Tensors)> step;
  492. bool is_tf_rnn_cell = _cell.IsTFRnnCell;
  493. if (constants is not null)
  494. {
  495. if (!_cell.SupportOptionalArgs)
  496. {
  497. throw new ValueError(
  498. $"RNN cell {_cell} does not support constants." +
  499. =======
  500. var cell_call_fn = cell.Call;
  501. Func<Tensors, Tensors, (Tensors, Tensors)> step;
  502. if (constants != null)
  503. {
  504. ParameterInfo[] parameters = cell_call_fn.GetMethodInfo().GetParameters();
  505. bool hasParam = parameters.Any(p => p.Name == "constants");
  506. if (!hasParam)
  507. {
  508. throw new ValueError(
  509. $"RNN cell {cell} does not support constants." +
  510. >>>>>>> master
  511. $"Received: constants={constants}");
  512. }
  513. step = (inputs, states) =>
  514. {
  515. <<<<<<< HEAD
  516. constants = new Tensors(states.TakeLast(_num_constants));
  517. states = new Tensors(states.SkipLast(_num_constants));
  518. states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
  519. var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
  520. return (output, new_states.Single);
  521. =======
  522. // constants = states[-self._num_constants :]
  523. constants = states.numpy()[new Slice(states.Length - _num_constants, states.Length)];
  524. // states = states[: -self._num_constants]
  525. states = states.numpy()[new Slice(0, states.Length - _num_constants)];
  526. // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
  527. states = states.Length == 1 ? states[0] : states;
  528. var (output, new_states) = cell_call_fn(inputs, null, null, states, constants);
  529. // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
  530. if (!nest.is_nested(new_states))
  531. {
  532. return (output, new Tensors { new_states });
  533. }
  534. return (output, new_states);
  535. >>>>>>> master
  536. };
  537. }
  538. else
  539. {
  540. step = (inputs, states) =>
  541. {
  542. <<<<<<< HEAD
  543. states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
  544. var (output, new_states) = _cell.Apply(inputs, states);
  545. return (output, new_states);
  546. };
  547. }
  548. var (last_output, outputs, states) = keras.backend.rnn(
  549. step,
  550. inputs,
  551. initial_state,
  552. constants: constants,
  553. go_backwards: _args.GoBackwards,
  554. mask: mask,
  555. unroll: _args.Unroll,
  556. input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
  557. time_major: _args.TimeMajor,
  558. zero_output_for_mask: _args.ZeroOutputForMask,
  559. return_all_outputs: _args.ReturnSequences);
  560. if (_args.Stateful)
  561. {
  562. throw new NotImplementedException("this argument havn't been developed.");
  563. }
  564. Tensors output = new Tensors();
  565. if (_args.ReturnSequences)
  566. {
  567. // TODO(Rinne): add go_backwards parameter and revise the `row_length` param
  568. output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
  569. =======
  570. // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
  571. states = states.Length == 1 ? states[0] : states;
  572. var (output, new_states) = cell_call_fn(inputs, null, null, states, constants);
  573. if (!nest.is_nested(new_states))
  574. {
  575. return (output, new Tensors { new_states });
  576. }
  577. return (output, new_states);
  578. };
  579. }
  580. var (last_output, outputs, states) = BackendImpl.rnn(step,
  581. inputs,
  582. initial_state,
  583. constants: constants,
  584. go_backwards: args.GoBackwards,
  585. mask: mask,
  586. unroll: args.Unroll,
  587. input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
  588. time_major: args.TimeMajor,
  589. zero_output_for_mask: args.ZeroOutputForMask,
  590. return_all_outputs: args.ReturnSequences);
  591. if (args.Stateful)
  592. {
  593. throw new NotImplementedException("this argument havn't been developed!");
  594. }
  595. Tensors output = new Tensors();
  596. if (args.ReturnSequences)
  597. {
  598. throw new NotImplementedException("this argument havn't been developed!");
  599. >>>>>>> master
  600. }
  601. else
  602. {
  603. output = last_output;
  604. }
  605. <<<<<<< HEAD
  606. if (_args.ReturnState)
  607. {
  608. =======
  609. if (args.ReturnState)
  610. {
  611. >>>>>>> master
  612. foreach (var state in states)
  613. {
  614. output.Add(state);
  615. }
  616. return output;
  617. }
  618. else
  619. {
  620. return output;
  621. }
  622. }
  623. <<<<<<< HEAD
  624. public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
  625. {
  626. RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
  627. if (optional_args is not null && rnn_optional_args is null)
  628. {
  629. throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
  630. }
  631. Tensors? constants = rnn_optional_args?.Constants;
  632. (inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);
  633. if(initial_states is null && constants is null)
  634. {
  635. return base.Apply(inputs);
  636. }
  637. // TODO(Rinne): implement it.
  638. throw new NotImplementedException();
  639. }
  640. private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
  641. {
  642. if (inputs.Length > 1)
  643. {
  644. if (_num_constants != 0)
  645. {
  646. initial_state = new Tensors(inputs.Skip(1));
  647. }
  648. else
  649. {
  650. initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants));
  651. constants = new Tensors(inputs.TakeLast(_num_constants));
  652. =======
  653. private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants)
  654. {
  655. if (nest.is_sequence(input))
  656. {
  657. if (_num_constants != 0)
  658. {
  659. initial_state = inputs[new Slice(1, len(inputs))];
  660. }
  661. else
  662. {
  663. initial_state = inputs[new Slice(1, len(inputs) - _num_constants)];
  664. constants = inputs[new Slice(len(inputs) - _num_constants, len(inputs))];
  665. >>>>>>> master
  666. }
  667. if (len(initial_state) == 0)
  668. initial_state = null;
  669. inputs = inputs[0];
  670. }
  671. <<<<<<< HEAD
  672. if (_args.Stateful)
  673. =======
  674. if (args.Stateful)
  675. >>>>>>> master
  676. {
  677. if (initial_state != null)
  678. {
  679. var tmp = new Tensor[] { };
  680. foreach (var s in nest.flatten(States))
  681. {
  682. <<<<<<< HEAD
  683. tmp.add(tf.math.count_nonzero(s.Single()));
  684. }
  685. var non_zero_count = tf.add_n(tmp);
  686. //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
  687. if ((int)non_zero_count.numpy() > 0)
  688. =======
  689. tmp.add(tf.math.count_nonzero((Tensor)s));
  690. }
  691. var non_zero_count = tf.add_n(tmp);
  692. //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
  693. if((int)non_zero_count.numpy() > 0)
  694. >>>>>>> master
  695. {
  696. initial_state = States;
  697. }
  698. }
  699. else
  700. {
  701. initial_state = States;
  702. }
  703. <<<<<<< HEAD
  704. // TODO(Wanglongzhi2001),
  705. // initial_state = tf.nest.map_structure(
  706. //# When the layer has a inferred dtype, use the dtype from the
  707. //# cell.
  708. // lambda v: tf.cast(
  709. // v, self.compute_dtype or self.cell.compute_dtype
  710. // ),
  711. // initial_state,
  712. // )
  713. }
  714. else if (initial_state is null)
  715. =======
  716. }
  717. else if(initial_state != null)
  718. >>>>>>> master
  719. {
  720. initial_state = get_initial_state(inputs);
  721. }
  722. if (initial_state.Length != States.Length)
  723. {
  724. <<<<<<< HEAD
  725. throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
  726. $"but it received {initial_state.Length} " +
  727. $"initial state(s). Input received: {inputs}");
  728. =======
  729. throw new ValueError(
  730. $"Layer {this} expects {States.Length} state(s), " +
  731. $"but it received {initial_state.Length} " +
  732. $"initial state(s). Input received: {inputs}");
  733. >>>>>>> master
  734. }
  735. return (inputs, initial_state, constants);
  736. }
  737. private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
  738. {
  739. <<<<<<< HEAD
  740. if (!is_ragged_input)
  741. =======
  742. if (!is_ragged_input)
  743. >>>>>>> master
  744. {
  745. return;
  746. }
  747. <<<<<<< HEAD
  748. if (_args.Unroll)
  749. =======
  750. if (args.Unroll)
  751. >>>>>>> master
  752. {
  753. throw new ValueError("The input received contains RaggedTensors and does " +
  754. "not support unrolling. Disable unrolling by passing " +
  755. "`unroll=False` in the RNN Layer constructor.");
  756. }
  757. if (mask != null)
  758. {
  759. throw new ValueError($"The mask that was passed in was {mask}, which " +
  760. "cannot be applied to RaggedTensor inputs. Please " +
  761. "make sure that there is no mask injected by upstream " +
  762. "layers.");
  763. }
  764. }
  765. void _maybe_reset_cell_dropout_mask(ILayer cell)
  766. {
  767. <<<<<<< HEAD
  768. if (cell is DropoutRNNCellMixin CellDRCMixin)
  769. {
  770. CellDRCMixin.reset_dropout_mask();
  771. CellDRCMixin.reset_recurrent_dropout_mask();
  772. }
  773. =======
  774. //if (cell is DropoutRNNCellMixin)
  775. //{
  776. // cell.reset_dropout_mask();
  777. // cell.reset_recurrent_dropout_mask();
  778. //}
  779. >>>>>>> master
  780. }
  781. private static RNNArgs PreConstruct(RNNArgs args)
  782. {
  783. if (args.Kwargs == null)
  784. {
  785. args.Kwargs = new Dictionary<string, object>();
  786. }
  787. // If true, the output for masked timestep will be zeros, whereas in the
  788. // false case, output from previous timestep is returned for masked timestep.
  789. var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
  790. Shape input_shape;
  791. var propIS = (Shape)args.Kwargs.Get("input_shape", null);
  792. var propID = (int?)args.Kwargs.Get("input_dim", null);
  793. var propIL = (int?)args.Kwargs.Get("input_length", null);
  794. if (propIS == null && (propID != null || propIL != null))
  795. {
  796. input_shape = new Shape(
  797. propIL ?? -1,
  798. propID ?? -1);
  799. args.Kwargs["input_shape"] = input_shape;
  800. }
  801. return args;
  802. }
  803. public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
  804. {
  805. throw new NotImplementedException();
  806. <<<<<<< HEAD
  807. =======
  808. }
  809. // 好像不能cell不能传接口类型
  810. //public RNN New(IRnnArgCell cell,
  811. // bool return_sequences = false,
  812. // bool return_state = false,
  813. // bool go_backwards = false,
  814. // bool stateful = false,
  815. // bool unroll = false,
  816. // bool time_major = false)
  817. // => new RNN(new RNNArgs
  818. // {
  819. // Cell = cell,
  820. // ReturnSequences = return_sequences,
  821. // ReturnState = return_state,
  822. // GoBackwards = go_backwards,
  823. // Stateful = stateful,
  824. // Unroll = unroll,
  825. // TimeMajor = time_major
  826. // });
  827. //public RNN New(List<IRnnArgCell> cell,
  828. // bool return_sequences = false,
  829. // bool return_state = false,
  830. // bool go_backwards = false,
  831. // bool stateful = false,
  832. // bool unroll = false,
  833. // bool time_major = false)
  834. // => new RNN(new RNNArgs
  835. // {
  836. // Cell = cell,
  837. // ReturnSequences = return_sequences,
  838. // ReturnState = return_state,
  839. // GoBackwards = go_backwards,
  840. // Stateful = stateful,
  841. // Unroll = unroll,
  842. // TimeMajor = time_major
  843. // });
  844. protected Tensors get_initial_state(Tensor inputs)
  845. {
  846. Type type = cell.GetType();
  847. MethodInfo MethodInfo = type.GetMethod("get_initial_state");
  848. if (nest.is_nested(inputs))
  849. {
  850. // The input are nested sequences. Use the first element in the seq
  851. // to get batch size and dtype.
  852. inputs = nest.flatten(inputs)[0];
  853. }
  854. var input_shape = tf.shape(inputs);
  855. var batch_size = args.TimeMajor ? input_shape[1] : input_shape[0];
  856. var dtype = inputs.dtype;
  857. Tensor init_state;
  858. if (MethodInfo != null)
  859. {
  860. init_state = (Tensor)MethodInfo.Invoke(cell, new object[] { null, batch_size, dtype });
  861. }
  862. else
  863. {
  864. init_state = RNNUtils.generate_zero_filled_state(batch_size, cell.state_size, dtype);
  865. }
  866. //if (!nest.is_nested(init_state))
  867. //{
  868. // init_state = new List<Tensor> { init_state};
  869. //}
  870. return new List<Tensor> { init_state };
  871. //return _generate_zero_filled_state_for_cell(null, null);
  872. >>>>>>> master
  873. }
// It seems the cell cannot be passed here as an interface type.
  875. //public RNN New(IRnnArgCell cell,
  876. // bool return_sequences = false,
  877. // bool return_state = false,
  878. // bool go_backwards = false,
  879. // bool stateful = false,
  880. // bool unroll = false,
  881. // bool time_major = false)
  882. // => new RNN(new RNNArgs
  883. // {
  884. // Cell = cell,
  885. // ReturnSequences = return_sequences,
  886. // ReturnState = return_state,
  887. // GoBackwards = go_backwards,
  888. // Stateful = stateful,
  889. // Unroll = unroll,
  890. // TimeMajor = time_major
  891. // });
  892. //public RNN New(List<IRnnArgCell> cell,
  893. // bool return_sequences = false,
  894. // bool return_state = false,
  895. // bool go_backwards = false,
  896. // bool stateful = false,
  897. // bool unroll = false,
  898. // bool time_major = false)
  899. // => new RNN(new RNNArgs
  900. // {
  901. // Cell = cell,
  902. // ReturnSequences = return_sequences,
  903. // ReturnState = return_state,
  904. // GoBackwards = go_backwards,
  905. // Stateful = stateful,
  906. // Unroll = unroll,
  907. // TimeMajor = time_major
  908. // });
  909. protected Tensors get_initial_state(Tensors inputs)
  910. {
  911. var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");
  912. var input = inputs[0];
  913. var input_shape = inputs.shape;
  914. var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
  915. var dtype = input.dtype;
  916. Tensors init_state = new Tensors();
  917. if(get_initial_state_fn != null)
  918. {
  919. init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
  920. }
  921. //if (_cell is RnnCellBase rnn_base_cell)
  922. //{
  923. // init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
  924. //}
  925. else
  926. {
  927. init_state = RnnUtils.generate_zero_filled_state(tf.convert_to_tensor(batch_size), _cell.StateSize, dtype);
  928. }
  929. return init_state;
  930. }
  931. // Check whether the state_size contains multiple states.
  932. <<<<<<< HEAD
  933. public static bool is_multiple_state(GeneralizedTensorShape state_size)
  934. =======
  935. public static bool is_multiple_state(object state_size)
  936. >>>>>>> master
  937. {
  938. return state_size.Shapes.Length > 1;
  939. }
  940. }
  941. }