You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BackendImpl.cs 41 kB

4 years ago
4 years ago
4 years ago
6 years ago
6 years ago
6 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using Tensorflow.NumPy;
  14. using System;
  15. using System.Linq;
  16. using System.Collections.Generic;
  17. using Tensorflow.Functions;
  18. using Tensorflow.Graphs;
  19. using Tensorflow.Common.Extensions;
  20. using static Tensorflow.Binding;
  21. using static Tensorflow.Graphs.SubGraphUtility;
  22. using Tensorflow.Util;
  23. using Tensorflow.Common.Types;
  24. namespace Tensorflow.Keras
  25. {
  26. public class BackendImpl : BackendBase
  27. {
  28. /* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */
  29. public Func<Array, double> py_sum = sum;
  30. public Func<Array, bool> py_all = all;
  31. //Func<Array, bool> py_any = any;
  32. //Func<double, double, double, IEnumerable<double>> py_slice = slice;
  33. public Session _SESSION => ops.get_default_session();
  34. public Graph _GRAPH;
  35. FuncGraph _CURRENT_SCRATCH_GRAPH;
  36. public Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES;
  37. //Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS;
  38. public bool _MANUAL_VAR_INIT = false;
  39. public List<string> _LOCAL_DEVICES = null;
  40. /* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */
  41. /// <summary>
  42. /// A global dictionary mapping graph objects to an index of counters used
  43. /// for various layer names in each graph.
  44. /// Allows to give unique autogenerated names to layers, in a graph-specific way.
  45. /// </summary>
  46. public Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
  47. public Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>();
  48. public Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();
  49. public _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();
  50. public BackendImpl()
  51. {
  52. }
  53. public void track_variable(IVariableV1 v)
  54. {
  55. if (tf.Context.executing_eagerly())
  56. {
  57. return;
  58. }
  59. var graph = v.Graph;
  60. if(graph is null)
  61. {
  62. graph = get_graph();
  63. }
  64. _GRAPH_VARIABLES[graph.graph_key] = v;
  65. }
  66. public Tensor placeholder(Shape shape = null,
  67. int ndim = -1,
  68. TF_DataType dtype = TF_DataType.DtInvalid,
  69. bool sparse = false,
  70. string name = null,
  71. bool ragged = false)
  72. {
  73. if (sparse)
  74. {
  75. throw new NotImplementedException("placeholder sparse is true");
  76. }
  77. else
  78. {
  79. return array_ops.placeholder(dtype: dtype, shape: shape, name: name);
  80. }
  81. }
  82. public Graph get_graph()
  83. {
  84. if (tf.Context.executing_eagerly())
  85. {
  86. if (_GRAPH == null)
  87. _GRAPH = new FuncGraph("keras_graph");
  88. return _GRAPH;
  89. }
  90. return ops.get_default_graph();
  91. }
  92. FuncGraph _scratch_graph()
  93. {
  94. if (_CURRENT_SCRATCH_GRAPH == null)
  95. _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph");
  96. return _CURRENT_SCRATCH_GRAPH;
  97. }
  98. public int get_uid(string prefix)
  99. {
  100. var graph = tf.get_default_graph();
  101. if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
  102. PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<string, int>());
  103. if (!PER_GRAPH_LAYER_NAME_UIDS[graph].ContainsKey(prefix))
  104. PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] = 0;
  105. PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] += 1;
  106. return PER_GRAPH_LAYER_NAME_UIDS[graph][prefix];
  107. }
  108. public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
  109. public void clear_session()
  110. {
  111. tf.Context.reset_context();
  112. reset_uids();
  113. // var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
  114. if (_GRAPH_LEARNING_PHASES != null)
  115. _GRAPH_LEARNING_PHASES.Clear();
  116. if (_GRAPH_LEARNING_PHASES != null)
  117. _GRAPH_LEARNING_PHASES.Clear();
  118. PER_GRAPH_LAYER_NAME_UIDS.Clear();
  119. _CURRENT_SCRATCH_GRAPH = null;
  120. _GRAPH = null;
  121. ops.set_default_session(tf.Session(ops.get_default_graph()));
  122. tf.enable_eager_execution();
  123. tf.Runner.ClearEagerOperationMap();
  124. GC.Collect();
  125. GC.WaitForPendingFinalizers();
  126. }
  127. public void manual_variable_initialization(bool value)
  128. {
  129. _MANUAL_VAR_INIT = value;
  130. }
  131. public Tensor mean(Tensor x, int axis = -1, bool keepdims = false)
  132. {
  133. if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL)
  134. x = math_ops.cast(x, TF_DataType.TF_FLOAT);
  135. return math_ops.reduce_mean(x, axis: axis, keepdims: false);
  136. }
  137. public GraphLearningPhase learning_phase()
  138. {
  139. var graph = tf.get_default_graph();
  140. if (_GRAPH_LEARNING_PHASES.ContainsKey(graph))
  141. {
  142. var phase = tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase");
  143. _GRAPH_LEARNING_PHASES[graph] = 0;
  144. }
  145. return _GRAPH_LEARNING_PHASES[graph];
  146. }
  147. public void set_learning_phase(bool value)
  148. {
  149. _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
  150. }
  151. public void set_value(IVariableV1 x, object value)
  152. {
  153. // TODO(Rinne): check the implementation.
  154. x.assign(value);
  155. }
  156. public void batch_set_value(List<(IVariableV1, NDArray)> tuples)
  157. {
  158. if (ops.executing_eagerly_outside_functions())
  159. {
  160. foreach (var (x, value) in tuples)
  161. x.assign(value, read_value: false);
  162. }
  163. else
  164. {
  165. throw new NotImplementedException("");
  166. }
  167. }
  168. /// <summary>
  169. /// Pads the 2nd and 3rd dimensions of a 4D tensor.
  170. /// </summary>
  171. /// <param name="x"></param>
  172. /// <param name="padding"></param>
  173. /// <param name="data_format"></param>
  174. /// <returns></returns>
  175. public Tensor spatial_2d_padding(Tensor x, NDArray padding = null, string data_format = null)
  176. {
  177. if (padding == null)
  178. padding = new[,] { { 1, 1 }, { 1, 1 } };
  179. NDArray pattern;
  180. if (data_format == "channels_first")
  181. pattern = new int[,]
  182. {
  183. { 0, 0 },
  184. { 0, 0 },
  185. { padding[0][0], padding[0][1] },
  186. { padding[1][0], padding[1][1] }
  187. };
  188. else
  189. pattern = new int[,]
  190. {
  191. { 0, 0 },
  192. { padding[0][0], padding[0][1] },
  193. { padding[1][0], padding[1][1] },
  194. { 0, 0 }
  195. };
  196. return array_ops.pad(x, pattern);
  197. }
  198. /// <summary>
  199. /// Method to evaluate a tensor in eager or in a tf.function.
  200. /// </summary>
  201. /// <param name="outputs"></param>
  202. /// <returns></returns>
  203. public NDArray eval_in_eager_or_function(Tensors outputs)
  204. {
  205. if (outputs[0].op.type == "Const")
  206. return tensor_util.constant_value(outputs);
  207. var source_graph = outputs.graph;
  208. var exec_graph = _scratch_graph();
  209. var global_graph = get_graph();
  210. if (source_graph == global_graph && exec_graph != global_graph)
  211. {
  212. var lifted_map = lift_to_graph(outputs, exec_graph,
  213. new List<Tensor>(),
  214. add_sources: true,
  215. handle_captures: true,
  216. base_graph: source_graph);
  217. }
  218. if (outputs[0].op.type == "Placeholder"
  219. || outputs[0].op.type == "StridedSlice")
  220. return exec_graph.external_captures.Last().numpy();
  221. // Consolidate updates
  222. exec_graph.as_default();
  223. exec_graph.Inputs = exec_graph.internal_captures;
  224. exec_graph.Outputs = outputs;
  225. var graph_fn = new ConcreteFunction(exec_graph);
  226. _CURRENT_SCRATCH_GRAPH = null;
  227. tf.Context.restore_mode();
  228. // return outputs.eval();
  229. throw new NotImplementedException("");
  230. }
  231. public class _DummyEagerGraph
  232. { }
  233. /// <summary>
  234. /// Categorical crossentropy between an output tensor and a target tensor.
  235. /// </summary>
  236. /// <param name="target"></param>
  237. /// <param name="output"></param>
  238. /// <param name="from_logits"></param>
  239. /// <param name="axis"></param>
  240. /// <returns></returns>
  241. public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
  242. {
  243. if (from_logits)
  244. return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis);
  245. if (output.op != null && output.op.type == "Softmax")
  246. {
  247. if (output.op.inputs.Length != 1) throw new ApplicationException();
  248. var o = output.op.inputs[0];
  249. return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: o, axis: axis);
  250. }
  251. // scale preds so that the class probas of each sample sum to 1
  252. output = output / math_ops.reduce_sum(output, new Axis(axis), true);
  253. // Compute cross entropy from probabilities.
  254. var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype());
  255. output = clip_ops.clip_by_value(output, epsilon_, 1.0f - epsilon_);
  256. return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis));
  257. }
  258. public Tensor sparse_categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1, int? ignore_class = null)
  259. {
  260. target = tf.cast(target, tf.int64);
  261. if (!from_logits)
  262. {
  263. var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype());
  264. output = tf.clip_by_value(output, epsilon_, 1 - epsilon_);
  265. output = tf.math.log(output);
  266. }
  267. var output_rank = output.shape.ndim;
  268. if (output_rank > -1)
  269. {
  270. axis = Math.Abs(axis) % output_rank;
  271. if (axis != output_rank - 1)
  272. {
  273. /*var permutation = list(
  274. itertools.chain(
  275. range(axis), range(axis + 1, output_rank), [axis]
  276. )
  277. );
  278. output = tf.transpose(output, perm: permutation);*/
  279. throw new NotImplementedException("");
  280. }
  281. }
  282. var output_shape = tf.shape(output);
  283. var target_rank = target.shape.ndim;
  284. var update_shape = target_rank > -1 && output_rank > -1 && target_rank != output_rank - 1;
  285. if (update_shape)
  286. {
  287. target = tf.reshape(target, -1);
  288. output = tf.reshape(output, (-1, output.shape[-1]));
  289. }
  290. if (ignore_class.HasValue)
  291. {
  292. throw new NotImplementedException("");
  293. }
  294. var res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels: target, logits: output);
  295. if (ignore_class.HasValue)
  296. {
  297. throw new NotImplementedException("");
  298. }
  299. if (update_shape && output_rank >= 3)
  300. {
  301. // If our output includes timesteps or
  302. // spatial dimensions we need to reshape
  303. res = tf.reshape(res, output_shape[":-1"]);
  304. }
  305. return res;
  306. }
  307. public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false)
  308. {
  309. if (from_logits)
  310. return tf.nn.sigmoid_cross_entropy_with_logits(labels: target, logits: output);
  311. var epsilon_ = constant_op.constant(epsilon(), dtype: output.dtype.as_base_dtype());
  312. output = tf.clip_by_value(output, epsilon_, 1.0f - epsilon_);
  313. // Compute cross entropy from probabilities.
  314. var bce = target * tf.math.log(output + epsilon());
  315. bce += (1 - target) * tf.math.log(1 - output + epsilon());
  316. return -bce;
  317. }
  318. /// <summary>
  319. /// Resizes the images contained in a 4D tensor.
  320. /// </summary>
  321. /// <param name="x"></param>
  322. /// <param name="height_factor"></param>
  323. /// <param name="width_factor"></param>
  324. /// <param name="data_format"></param>
  325. /// <param name="interpolation"></param>
  326. /// <returns></returns>
  327. public Tensor resize_images(Tensor x, int height_factor, int width_factor,
  328. string data_format, string interpolation = "nearest")
  329. {
  330. var (rows, cols) = (0, 0);
  331. if (data_format == "channels_first")
  332. (rows, cols) = (2, 3);
  333. else if (data_format == "channels_last")
  334. (rows, cols) = (1, 2);
  335. else
  336. throw new ValueError($"Invalid `data_format` argument: {data_format}");
  337. var original_shape = x.shape;
  338. var new_shape = array_ops.shape(x)[new Slice(rows, cols + 1)];
  339. new_shape *= constant_op.constant(np.array(height_factor, width_factor));
  340. if (data_format == "channels_first")
  341. // x = permute_dimensions(x, [0, 2, 3, 1]);
  342. throw new NotImplementedException("");
  343. if (interpolation == "nearest")
  344. x = tf.image.resize_images_v2(x, new_shape, method: ResizeMethod.NEAREST_NEIGHBOR);
  345. if (data_format == "channels_first")
  346. // x = permute_dimensions(x, [0, 3, 1, 2]);
  347. throw new NotImplementedException("");
  348. int new_height = original_shape[rows] < 0 ? -1 : (int)original_shape[rows] * height_factor;
  349. int new_width = original_shape[cols] < 0 ? -1 : (int)original_shape[cols] * width_factor;
  350. Shape output_shape = data_format == "channels_first" ?
  351. (-1, -1, new_height, new_width) : (-1, new_height, new_width, -1);
  352. x.shape = output_shape;
  353. return x;
  354. }
  355. /// <summary>
  356. /// Concatenates a list of tensors alongside the specified axis.
  357. /// </summary>
  358. /// <param name="tensors">list of tensors to concatenate.</param>
  359. /// <param name="axis">concatenation axis.</param>
  360. /// <returns></returns>
  361. public Tensor concatenate(Tensors tensors, int axis = -1)
  362. {
  363. if(axis < 0)
  364. {
  365. var rank = tensors[0].ndim;
  366. if (rank > -1)
  367. axis += rank;
  368. else
  369. axis = 0;
  370. }
  371. return array_ops.concat(tensors, axis);
  372. }
  373. public Tensor conv2d_transpose(Tensor x,
  374. IVariableV1 kernel,
  375. Tensor output_shape,
  376. Shape strides = null,
  377. string padding = "valid",
  378. string data_format = null,
  379. Shape dilation_rate = null)
  380. {
  381. /*
  382. var force_transpose = false;
  383. if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 }))
  384. force_transpose = true;
  385. x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
  386. */
  387. var tf_data_format = "NHWC";
  388. padding = padding.ToUpper();
  389. strides = new Shape(1, strides[0], strides[1], 1);
  390. if (dilation_rate.Equals(new long[] { 1, 1 }))
  391. x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides,
  392. padding: padding,
  393. data_format: tf_data_format);
  394. else
  395. throw new NotImplementedException("dilation_rate other than [1,1] is not yet supported");
  396. return x;
  397. }
  398. public (Tensors, Tensors, Tensors) rnn(
  399. Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states
  400. Tensors inputs, // inputs is a tuple of tensors (one per input sequence)
  401. Tensors initial_states,
  402. bool go_backwards = false,
  403. Tensor? mask = null,
  404. Tensors? constants = null,
  405. bool unroll = false,
  406. Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not
  407. bool time_major = false,
  408. bool zero_output_for_mask = false,
  409. bool return_all_outputs = true)
  410. {
  411. Tensor swap_batch_timestep(Tensor input_t)
  412. {
  413. var axes = Enumerable.Range(0, input_t.rank).ToArray();
  414. axes[0] = 1;
  415. axes[1] = 0;
  416. return tf.transpose(input_t, axes);
  417. }
  418. if (!time_major)
  419. {
  420. inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors();
  421. }
  422. var flatted_inptus = Nest.Flatten(inputs).ToList();
  423. var first_flatted_input = flatted_inptus[0];
  424. var time_steps = first_flatted_input.shape[0];
  425. var batch = first_flatted_input.shape[1];
  426. var time_steps_t = (int)first_flatted_input.shape[0];
  427. foreach (var input_ in flatted_inptus)
  428. {
  429. input_.shape.with_rank_at_least(3);
  430. }
  431. if (mask != null)
  432. {
  433. if (mask.dtype != TF_DataType.TF_BOOL)
  434. {
  435. mask = tf.cast(mask, TF_DataType.TF_BOOL);
  436. }
  437. if (mask.rank == 2)
  438. {
  439. mask = tf.expand_dims(mask, -1);
  440. }
  441. if (!time_major)
  442. {
  443. mask = swap_batch_timestep(mask);
  444. }
  445. }
  446. // tf.where needs its condition tensor to be the same shape as its two
  447. // result tensors, but in our case the condition (mask) tensor is
  448. // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
  449. // So we need to broadcast the mask to match the shape of inputs.
  450. // That's what the tile call does, it just repeats the mask along its
  451. // second dimension n times.
  452. Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
  453. {
  454. if (!mask_t.IsSingle())
  455. {
  456. throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}");
  457. }
  458. if (!input_t.IsSingle())
  459. {
  460. throw new ValueError($"input_t is expected to be tensor, but got {input_t}");
  461. }
  462. var rank_diff = input_t.rank - mask_t.rank;
  463. for (int i = 0; i < rank_diff; i++)
  464. {
  465. mask_t = tf.expand_dims(mask_t, -1);
  466. }
  467. var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
  468. return tf.tile(mask_t, multiples);
  469. }
  470. Tensors outputs = new Tensors();
  471. Tensors output_time_zero = new Tensors();
  472. Tensors last_output = new Tensors();
  473. Tensors new_states = new Tensors();
  474. if (unroll)
  475. {
  476. if (time_steps == 0)
  477. {
  478. throw new ValueError("Unrolling requires a fixed number of timesteps.");
  479. }
  480. // Process the input tensors. The input tensor need to be split on the
  481. // time_step dim, and reverse if go_backwards is True. In the case of
  482. // nested input, the input is flattened and then transformed
  483. // individually. The result of this will be a tuple of lists, each of
  484. // the item in tuple is list of the tensor with shape (batch, feature)
  485. // TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple
  486. //var states = Tuple.Create(initial_states);
  487. var states = initial_states;
  488. var successive_states = new Tensors();
  489. var successive_outputs = new Tensors();
  490. // Process the input tensors. The input tensor need to be split on the
  491. // time_step dim, and reverse if go_backwards is True. In the case of
  492. // nested input, the input is flattened and then transformed
  493. // individually. The result of this will be a tuple of lists, each of
  494. // the item in tuple is list of the tensor with shape (batch, feature)
  495. Tensors _process_single_input_t(Tensor input_t)
  496. {
  497. var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
  498. if (go_backwards)
  499. {
  500. unstaked_input_t = unstaked_input_t.Reverse().ToArray();
  501. }
  502. return unstaked_input_t;
  503. }
  504. // TODO(Wanglongzhi2001)
  505. Tensors processed_input;
  506. if (!inputs.IsSingle())
  507. {
  508. processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo<Tensors, Tensor>().ToTensors();
  509. }
  510. else
  511. {
  512. processed_input = _process_single_input_t(inputs);
  513. }
  514. object _get_input_tensor(int time)
  515. {
  516. List<Tensor> inp = new List<Tensor>();
  517. foreach (var t_ in processed_input)
  518. {
  519. inp.Add(t_[time]);
  520. }
  521. return Nest.PackSequenceAs(inputs, inp);
  522. }
  523. if (mask != null)
  524. {
  525. var mask_list = tf.unstack(mask);
  526. if (go_backwards)
  527. {
  528. mask_list.Reverse();
  529. }
  530. for (int i = 0; i < time_steps; i++)
  531. {
  532. // TODO(Wanglongzhi2001),deal with _get_input_tensor
  533. var inp = _get_input_tensor(i);
  534. var mask_t = mask_list[i];
  535. // TODO
  536. var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));
  537. var tiled_mask_t = _expand_mask(mask_t, output);
  538. Tensors prev_output;
  539. if (successive_outputs == null)
  540. {
  541. prev_output = tf.zeros_like(output);
  542. }
  543. else
  544. {
  545. prev_output = successive_outputs[successive_outputs.Length - 1];
  546. }
  547. output = tf.where(tiled_mask_t, output, prev_output);
  548. var flat_states = Nest.Flatten(states).ToList();
  549. var flat_new_states = Nest.Flatten(newStates).ToList();
  550. var tiledMaskT = flat_states
  551. .Select(s => _expand_mask(mask_t, s))
  552. .ToArray();
  553. var tuple = Tuple.Create(tiledMaskT);
  554. List<Tensor> flat_final_states = new List<Tensor>();
  555. foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states))
  556. {
  557. flat_final_states.Add(tf.where(m, s, ps));
  558. }
  559. states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
  560. if (return_all_outputs)
  561. {
  562. successive_outputs.Add(output);
  563. successive_states.Add(states);
  564. }
  565. else
  566. {
  567. successive_outputs = new Tensors { output };
  568. successive_states = new Tensors { states };
  569. }
  570. }
  571. last_output = successive_outputs[successive_outputs.Length - 1];
  572. new_states = successive_states[successive_states.Length - 1];
  573. outputs = tf.stack(successive_outputs);
  574. if (zero_output_for_mask)
  575. {
  576. last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
  577. outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
  578. }
  579. else // mask is null
  580. {
  581. for (int i = 0; i < time_steps; i++)
  582. {
  583. var inp = _get_input_tensor(i);
  584. var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));
  585. states = newStates;
  586. if (return_all_outputs)
  587. {
  588. successive_outputs.Add(output);
  589. successive_states.Add(newStates);
  590. }
  591. else
  592. {
  593. successive_outputs = new Tensors { output };
  594. successive_states = new Tensors { newStates };
  595. }
  596. }
  597. last_output = successive_outputs[successive_outputs.Length - 1];
  598. new_states = successive_states[successive_states.Length - 1];
  599. outputs = tf.stack(successive_outputs);
  600. }
  601. }
  602. }
  603. else // unroll == false
  604. {
  605. var states = initial_states;
  606. // Create input tensor array, if the inputs is nested tensors, then it
  607. // will be flattened first, and tensor array will be created one per
  608. // flattened tensor.
  609. var input_ta = new List<TensorArray>();
  610. for (int i = 0; i < flatted_inptus.Count; i++)
  611. {
  612. input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t));
  613. }
  614. foreach(var (ta, input_) in zip(input_ta, flatted_inptus))
  615. {
  616. if (!go_backwards)
  617. {
  618. ta.unstack(input_);
  619. }
  620. else
  621. {
  622. ta.unstack(reverse(input_, 0));
  623. }
  624. }
  625. // Get the time(0) input and compute the output for that, the output will
  626. // be used to determine the dtype of output tensor array. Don't read from
  627. // input_ta due to TensorArray clear_after_read default to True.
  628. var inps = new Tensors();
  629. foreach (var inp in flatted_inptus)
  630. {
  631. inps.Add(inp[0]);
  632. }
  633. var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();
  634. // output_time_zero is used to determine the cell output shape and its
  635. // dtype. the value is discarded.
  636. (output_time_zero, _) = step_function((Tensor)input_time_zero,
  637. constants is null ? initial_states : initial_states.MergeWith(constants));
  638. int output_ta_size = return_all_outputs ? time_steps_t : 1;
  639. var output_ta = new List<TensorArray>();
  640. for (int i = 0; i < output_time_zero.ToList().Count; i++)
  641. {
  642. var Out = output_time_zero.ToList()[i];
  643. output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
  644. }
  645. var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");
  646. Func<Tensor, Tensor>? masking_fn;
  647. Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
  648. if (mask != null)
  649. {
  650. if (go_backwards)
  651. {
  652. mask = tf.reverse(mask, axis: new[] { 0 });
  653. }
  654. var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
  655. mask_ta = mask_ta.unstack(mask);
  656. masking_fn = (time) =>
  657. {
  658. return mask_ta.read(time);
  659. };
  660. compute_masked_output = (mask_t, flat_out, flat_mask) =>
  661. {
  662. var tiled_mask_t = new Tensors();
  663. foreach (var o in flat_out)
  664. {
  665. tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
  666. }
  667. Tensors res = new Tensors();
  668. foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList()))
  669. {
  670. res.Add(tf.where(m, o, fm));
  671. }
  672. return res;
  673. };
  674. }
  675. // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
  676. else if (input_length is Tensor)
  677. {
  678. if (go_backwards)
  679. {
  680. var max_len = tf.reduce_max(input_length, axis: 0);
  681. var rev_input_length = tf.subtract(max_len - 1, input_length);
  682. masking_fn = (time) =>
  683. {
  684. return tf.less(rev_input_length, time);
  685. };
  686. }
  687. else
  688. {
  689. masking_fn = (time) =>
  690. {
  691. return tf.greater(input_length, time);
  692. };
  693. }
  694. compute_masked_output = (mask_t, flat_out, flat_mask) =>
  695. {
  696. var res = new List<Tensor>();
  697. foreach (var (o, zo) in zip(flat_out, flat_mask))
  698. {
  699. res.Add(tf.where(mask_t, o, zo));
  700. }
  701. return res;
  702. };
  703. }
  704. else
  705. {
  706. masking_fn = null;
  707. }
  708. Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
  709. int parallel_iterations = 32;
  710. if (masking_fn != null)
  711. {
  712. // Mask for the T output will be base on the output of T - 1. In the
  713. // case T = 0, a zero filled tensor will be used.
  714. var flat_zero_output = new Tensors();
  715. foreach (var o in Nest.Flatten(output_time_zero))
  716. {
  717. flat_zero_output.Add(tf.zeros_like(o));
  718. }
  719. var prev_output = flat_zero_output;
  720. var output_ta_t = output_ta;
  721. Tensor _step(Tensor time)
  722. {
  723. /*
  724. RNN step function.
  725. Args:
  726. time: Current timestep value.
  727. output_ta_t: TensorArray.
  728. prev_output: tuple of outputs from time - 1.
  729. *states: List of states.
  730. Returns:
  731. Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
  732. */
  733. var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
  734. // maybe set shape
  735. // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
  736. var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
  737. var mask_t = masking_fn(time);
  738. var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
  739. // mask output
  740. var flat_output = Nest.Flatten(output).ToList();
  741. var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();
  742. // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
  743. var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);
  744. // mask states
  745. var flat_state = states.ToList();
  746. var flat_new_state = new_states_internal.ToList();
  747. foreach (var (state, new_state) in zip(flat_state, flat_new_state))
  748. {
  749. if (new_state is Tensor)
  750. {
  751. new_state.shape = state.shape;
  752. }
  753. }
  754. var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
  755. new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();
  756. var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
  757. // TODO(Wanglongzhi2001),deal with zip output_ta_t
  758. foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
  759. {
  760. output_ta_t.Add(ta.write(ta_index_to_write, Out));
  761. }
  762. new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
  763. output_ta = output_ta_t;
  764. new_states = new_states_internal;
  765. return time + 1;
  766. }
  767. var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
  768. }
  769. else
  770. {
  771. var output_ta_t = output_ta;
  772. new_states = states;
  773. Tensor _step(Tensor time)
  774. {
  775. var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
  776. // maybe set shape
  777. // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
  778. var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
  779. var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
  780. var flat_state = new_states.Flatten().ToList();
  781. var flat_new_state = new_states_internal.Flatten().ToList();
  782. foreach (var (state, new_state) in zip(flat_state, flat_new_state))
  783. {
  784. if (new_state is Tensor)
  785. {
  786. new_state.shape = state.shape;
  787. }
  788. }
  789. var flat_output = Nest.Flatten(output);
  790. var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
  791. output_ta_t = zip(output_ta_t, flat_output).Select(item =>
  792. {
  793. var (ta, out_) = item;
  794. return ta.write(ta_index_to_write, out_);
  795. }).ToList();
  796. new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
  797. output_ta = output_ta_t;
  798. new_states = new_states_internal;
  799. return time + 1;
  800. }
  801. var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
  802. }
  803. //Tensors outputs = new Tensors();
  804. foreach (var o in output_ta)
  805. {
  806. outputs.Add(o.stack());
  807. }
  808. foreach (var o in outputs)
  809. {
  810. last_output.Add(o[-1]);
  811. }
  812. outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
  813. last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();
  814. }
  815. Func<Tensor, Tensor> set_shape;
  816. set_shape = (output_) =>
  817. {
  818. if (output_ is Tensor)
  819. {
  820. var shape = output_.shape.as_int_list();
  821. if (return_all_outputs)
  822. {
  823. shape[0] = (int)time_steps;
  824. }
  825. else
  826. {
  827. shape[0] = 1;
  828. }
  829. shape[1] = (int)batch;
  830. output_.shape = shape;
  831. }
  832. return output_;
  833. };
  834. outputs = Nest.MapStructure(set_shape, outputs).ToTensors();
  835. if (!time_major)
  836. {
  837. outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors();
  838. }
  839. return (last_output, outputs, new_states);
  840. }
  /// <summary>
  /// Reverses <paramref name="input"/> along a single axis.
  /// </summary>
  /// <param name="input">Tensor to reverse.</param>
  /// <param name="axis">The one axis to flip; wrapped in a one-element array.</param>
  /// <returns>The result of the <c>int[]</c> overload of <c>reverse</c>.</returns>
  public Tensor reverse(Tensor input, int axis) => reverse(input, new[] { axis });
  /// <summary>
  /// Reverses <paramref name="input"/> along every axis listed in <paramref name="axes"/>.
  /// </summary>
  /// <param name="input">Tensor to reverse.</param>
  /// <param name="axes">Axes to flip, forwarded as-is to <c>tf.reverse</c>.</param>
  /// <returns>The tensor produced by <c>tf.reverse</c>.</returns>
  public Tensor reverse(Tensor input, int[] axes) => tf.reverse(input, axes);
  /// <summary>
  /// Returns <paramref name="output"/> unchanged for dense (non-ragged) results.
  /// The ragged conversion path is not implemented yet.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the name and parameters appear to mirror Keras'
  /// <c>backend.maybe_convert_to_ragged</c>. <paramref name="nested_row_lengths"/> and
  /// <paramref name="go_backwards"/> are currently unused here; confirm intended types
  /// once ragged support is added.
  /// </remarks>
  /// <param name="is_ragged_output">When true, a ragged conversion is requested, which is unsupported.</param>
  /// <param name="output">The tensor to return when no conversion is needed.</param>
  /// <param name="nested_row_lengths">Unused at present.</param>
  /// <param name="go_backwards">Unused at present.</param>
  /// <returns><paramref name="output"/> itself when <paramref name="is_ragged_output"/> is false.</returns>
  /// <exception cref="NotImplementedException">Thrown whenever <paramref name="is_ragged_output"/> is true.</exception>
  public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false)
  {
      // Reject the unsupported ragged case up front; dense outputs fall through untouched.
      if (is_ragged_output)
      {
          throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
      }

      return output;
  }
  857. }
  858. }