You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

BackendImpl.cs 9.2 kB

6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using NumSharp;
  14. using System;
  15. using System.Collections.Generic;
  16. using static Tensorflow.Binding;
  namespace Tensorflow.Keras
  {
      /// <summary>
      /// Keras backend implementation: per-graph layer-name counters, learning-phase
      /// bookkeeping, variable/optimizer registries and a few tensor helpers built on
      /// TensorFlow ops.
      /// </summary>
      public class BackendImpl : BackendBase
      {
          /* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */
          public Func<Array, double> py_sum = sum;
          public Func<Array, bool> py_all = all;
          // TODO: py_any / py_slice from the Python backend are not ported yet.

          /// <summary>The current default session (<c>ops.get_default_session()</c>).</summary>
          public Session _SESSION => ops.get_default_session();
          public Graph _GRAPH;

          /// <summary>
          /// Learning phase per graph. Initialized eagerly so that <see cref="learning_phase"/>
          /// and <see cref="set_learning_phase"/> do not dereference null before
          /// <see cref="clear_session"/> has ever been called.
          /// </summary>
          public Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();

          /// <summary>Flag written by <see cref="manual_variable_initialization"/>.</summary>
          public bool _MANUAL_VAR_INIT = false;
          public List<string> _LOCAL_DEVICES = null;
          /* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */

          /// <summary>
          /// A global dictionary mapping graph objects to an index of counters used
          /// for various layer names in each graph.
          /// Allows to give unique autogenerated names to layers, in a graph-specific way.
          /// </summary>
          public Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
          public Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>();
          public Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();
          public _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();

          public BackendImpl()
          {
          }

          /// <summary>
          /// Records <paramref name="v"/> in <see cref="_GRAPH_VARIABLES"/> under the key of
          /// the graph it belongs to.
          /// </summary>
          /// <remarks>
          /// NOTE(review): only one variable is kept per graph key; a later call for a variable
          /// of the same graph replaces the earlier entry.
          /// </remarks>
          /// <param name="v">Variable to track.</param>
          public void track_variable(IVariableV1 v)
          {
              var graph = v.Graph;
              _GRAPH_VARIABLES[graph.graph_key] = v;
          }

          /// <summary>
          /// Creates a dense placeholder tensor via <c>array_ops.placeholder</c>.
          /// </summary>
          /// <param name="shape">Placeholder shape, passed through unchanged.</param>
          /// <param name="ndim">Currently ignored by this implementation.</param>
          /// <param name="dtype">Element type of the placeholder.</param>
          /// <param name="sparse">Sparse placeholders are not supported.</param>
          /// <param name="name">Optional op name.</param>
          /// <param name="ragged">Currently ignored; a dense placeholder is always returned.</param>
          /// <returns>The placeholder tensor.</returns>
          /// <exception cref="NotImplementedException"><paramref name="sparse"/> is true.</exception>
          public Tensor placeholder(TensorShape shape = null,
              int ndim = -1,
              TF_DataType dtype = TF_DataType.DtInvalid,
              bool sparse = false,
              string name = null,
              bool ragged = false)
          {
              if (sparse)
                  throw new NotImplementedException("placeholder sparse is true");

              return array_ops.placeholder(dtype: dtype, shape: shape, name: name);
          }

          /// <summary>Returns the current default graph (<c>ops.get_default_graph()</c>).</summary>
          public Graph get_graph()
          {
              return ops.get_default_graph();
          }

          /// <summary>
          /// Increments and returns the per-graph counter for <paramref name="prefix"/>.
          /// The first call for a prefix in a graph returns 1.
          /// </summary>
          /// <param name="prefix">Name prefix (e.g. a layer class name).</param>
          /// <returns>The new counter value for this prefix in the default graph.</returns>
          public int get_uid(string prefix)
          {
              var graph = tf.get_default_graph();
              if (!PER_GRAPH_LAYER_NAME_UIDS.TryGetValue(graph, out var uids))
              {
                  uids = new defaultdict<string, int>();
                  PER_GRAPH_LAYER_NAME_UIDS.Add(graph, uids);
              }

              // A missing prefix leaves count at 0, so the first uid handed out is 1.
              uids.TryGetValue(prefix, out var count);
              count += 1;
              uids[prefix] = count;
              return count;
          }

          /// <summary>Discards all per-graph layer-name counters.</summary>
          public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();

          /// <summary>
          /// Resets the default graph, the layer-name counters and the learning-phase table,
          /// and installs a new default session bound to the fresh default graph.
          /// </summary>
          public void clear_session()
          {
              ops.reset_default_graph();
              reset_uids();
              ops.set_default_session(tf.Session(ops.get_default_graph()));
              // The returned tensor is not stored; the call is kept so the
              // "keras_learning_phase" op is still created in the new default graph.
              tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
              _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
              _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
          }

          /// <summary>Sets <see cref="_MANUAL_VAR_INIT"/> to <paramref name="value"/>.</summary>
          public void manual_variable_initialization(bool value)
          {
              _MANUAL_VAR_INIT = value;
          }

          /// <summary>
          /// Returns the learning phase of the default graph. If the graph has no entry yet,
          /// a "keras_learning_phase" placeholder is created and the phase is recorded as 0.
          /// </summary>
          public GraphLearningPhase learning_phase()
          {
              var graph = tf.get_default_graph();
              // Only initialize when the graph is missing; an existing phase (e.g. one set via
              // set_learning_phase) must be returned unchanged.
              if (!_GRAPH_LEARNING_PHASES.TryGetValue(graph, out var phase))
              {
                  tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase");
                  phase = 0;
                  _GRAPH_LEARNING_PHASES[graph] = phase;
              }
              return phase;
          }

          /// <summary>
          /// Records the learning phase of the default graph: 1 when <paramref name="value"/>
          /// is true, 0 otherwise.
          /// </summary>
          public void set_learning_phase(bool value)
          {
              _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)(value ? 1 : 0);
          }

          /// <summary>
          /// Pads the 2nd and 3rd dimensions of a 4D tensor.
          /// </summary>
          /// <param name="x">Tensor to pad.</param>
          /// <param name="padding">2x2 padding amounts; defaults to <c>{ { 1, 1 }, { 1, 1 } }</c>.</param>
          /// <param name="data_format">
          /// "channels_first", "channels_last", or null (treated as "channels_last").
          /// </param>
          /// <returns>The padded tensor.</returns>
          /// <exception cref="ValueError"><paramref name="data_format"/> is not a recognized value.</exception>
          public Tensor spatial_2d_padding(Tensor x, NDArray padding = null, string data_format = null)
          {
              if (data_format != null && data_format != "channels_first" && data_format != "channels_last")
                  throw new ValueError($"Unknown data_format: {data_format}");

              if (padding == null)
                  padding = new[,] { { 1, 1 }, { 1, 1 } };

              NDArray pattern;
              if (data_format == "channels_first")
                  pattern = new int[,]
                  {
                      { 0, 0 },
                      { 0, 0 },
                      { padding[0][0], padding[0][1] },
                      { padding[1][0], padding[1][1] }
                  };
              else
                  pattern = new int[,]
                  {
                      { 0, 0 },
                      { padding[0][0], padding[0][1] },
                      { padding[1][0], padding[1][1] },
                      { 0, 0 }
                  };
              return array_ops.pad(x, pattern);
          }

          /// <summary>
          /// Method to evaluate a tensor in eager or in a tf.function.
          /// This implementation simply returns <c>outputs.eval()</c>.
          /// </summary>
          /// <param name="outputs">Tensor to evaluate.</param>
          /// <returns>The evaluated value.</returns>
          public NDArray eval_in_eager_or_function(Tensor outputs)
          {
              return outputs.eval();
          }

          /// <summary>Marker type; has no members.</summary>
          public class _DummyEagerGraph
          { }

          /// <summary>
          /// Categorical crossentropy between an output tensor and a target tensor.
          /// Only the <paramref name="from_logits"/> = true path is implemented.
          /// </summary>
          /// <param name="target">Target tensor.</param>
          /// <param name="output">Output tensor (logits when <paramref name="from_logits"/> is true).</param>
          /// <param name="from_logits">Whether <paramref name="output"/> holds logits.</param>
          /// <param name="axis">Class axis.</param>
          /// <returns>The crossentropy tensor.</returns>
          /// <exception cref="NotImplementedException"><paramref name="from_logits"/> is false.</exception>
          public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
          {
              if (from_logits)
                  return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis);

              throw new NotImplementedException("categorical_crossentropy with from_logits = false is not implemented");
          }

          /// <summary>
          /// Resizes the images contained in a 4D tensor by integer factors.
          /// </summary>
          /// <param name="x">Image tensor.</param>
          /// <param name="height_factor">Height scaling factor.</param>
          /// <param name="width_factor">Width scaling factor.</param>
          /// <param name="data_format">"channels_first" or "channels_last"; only "channels_last" is implemented.</param>
          /// <param name="interpolation">Only "nearest" is implemented.</param>
          /// <returns>The resized tensor with its static shape updated.</returns>
          /// <exception cref="ValueError">Invalid <paramref name="data_format"/> or <paramref name="interpolation"/>.</exception>
          /// <exception cref="NotImplementedException">"channels_first" data format or "bilinear" interpolation.</exception>
          public Tensor resize_images(Tensor x, int height_factor, int width_factor,
              string data_format, string interpolation = "nearest")
          {
              var (rows, cols) = (0, 0);
              if (data_format == "channels_first")
                  (rows, cols) = (2, 3);
              else if (data_format == "channels_last")
                  (rows, cols) = (1, 2);
              else
                  throw new ValueError($"Invalid `data_format` argument: {data_format}");

              var original_shape = x.shape;
              var new_shape = array_ops.shape(x)[new Slice(rows, cols + 1)];
              new_shape *= constant_op.constant(np.array(height_factor, width_factor));

              if (data_format == "channels_first")
                  // x = permute_dimensions(x, [0, 2, 3, 1]);
                  throw new NotImplementedException("resize_images with data_format `channels_first` requires permute_dimensions, which is not implemented");

              if (interpolation == "nearest")
                  x = tf.image.resize_images_v2(x, new_shape, method: ResizeMethod.NEAREST_NEIGHBOR);
              else if (interpolation == "bilinear")
                  throw new NotImplementedException("resize_images with interpolation `bilinear` is not implemented");
              else
                  // Previously any other value skipped the resize silently and then set a
                  // shape that did not match the tensor.
                  throw new ValueError($"interpolation should be one of \"nearest\" or \"bilinear\", got: {interpolation}");

              if (data_format == "channels_first")
                  // x = permute_dimensions(x, [0, 3, 1, 2]);
                  throw new NotImplementedException("resize_images with data_format `channels_first` requires permute_dimensions, which is not implemented");

              // Unknown (negative) input dimensions stay unknown in the output shape.
              int new_height = original_shape[rows] < 0 ? -1 : original_shape[rows] * height_factor;
              int new_width = original_shape[cols] < 0 ? -1 : original_shape[cols] * width_factor;
              TensorShape output_shape = data_format == "channels_first" ?
                  (-1, -1, new_height, new_width) : (-1, new_height, new_width, -1);
              x.set_shape(output_shape);
              return x;
          }
      }
  }