You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_layer_utils.cs 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using NumSharp;
  14. using System;
  15. using System.Collections.Generic;
  16. using System.Linq;
  17. using System.Reflection;
  18. using Tensorflow.Keras.ArgsDefinition;
  19. using Tensorflow.Keras.Engine;
  20. using static Tensorflow.Binding;
  21. using static Tensorflow.KerasApi;
  22. namespace Tensorflow.Keras.Utils
  23. {
  24. public class base_layer_utils
  25. {
  /// <summary>
  /// Adds a new variable to the layer.
  /// </summary>
  /// <param name="args">
  /// Supplies the initializer, shape, dtype and name of the variable, together with the
  /// trainable / validate-shape / use-resource flags that are forwarded to <c>tf.Variable</c>.
  /// </param>
  /// <returns>The variable returned by <c>tf.Variable</c>.</returns>
  public static IVariableV1 make_variable(VariableArgs args)
  {
      // The initial value is wrapped in a delegate, so the initializer is only applied
      // when that delegate is invoked.
      Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs(args.Shape, dtype: args.DType));

      var variable_dtype = args.DType.as_base_dtype();

      return tf.Variable(init_val,
          dtype: variable_dtype,
          shape: args.Shape,
          name: args.Name,
          trainable: args.Trainable,
          validate_shape: args.ValidateShape,
          use_resource: args.UseResource);
  }
  /// <summary>
  /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
  /// </summary>
  /// <param name="name">Base name to derive the unique name from.</param>
  /// <param name="name_uid_map">
  /// Per-name counters, updated in place. When null, the map returned by
  /// <c>get_default_graph_uid_map()</c> is used.
  /// </param>
  /// <param name="avoid_names">Names that must not be returned. Null means there are none.</param>
  /// <param name="zero_based">
  /// When true, the first candidate is the bare <paramref name="name"/> and later candidates are
  /// <c>name_1</c>, <c>name_2</c>, ...; when false, candidates start at <c>name_1</c>.
  /// </param>
  /// <returns>The first generated candidate that is not listed in <paramref name="avoid_names"/>.</returns>
  public static string unique_layer_name(string name, Dictionary<string, int> name_uid_map = null,
      string[] avoid_names = null, bool zero_based = false)
  {
      if (name_uid_map == null)
          name_uid_map = get_default_graph_uid_map();

      // A set gives constant-time membership checks on every retry; nothing is allocated
      // when there are no names to avoid.
      HashSet<string> reserved = null;
      if (avoid_names != null && avoid_names.Length > 0)
          reserved = new HashSet<string>(avoid_names);

      string proposed_name;
      do
      {
          // A missing entry leaves uid at 0, which is the value a new counter starts from.
          name_uid_map.TryGetValue(name, out var uid);

          if (zero_based)
          {
              proposed_name = uid > 0 ? $"{name}_{uid}" : name;
              name_uid_map[name] = uid + 1;
          }
          else
          {
              uid += 1;
              name_uid_map[name] = uid;
              proposed_name = $"{name}_{uid}";
          }
      }
      while (reserved != null && reserved.Contains(proposed_name));

      return proposed_name;
  }
  /// <summary>
  /// Returns the name counter map stored for the current default graph in
  /// <c>keras.backend.PER_GRAPH_LAYER_NAME_UIDS</c>, storing a new empty map first
  /// when the graph has no entry yet.
  /// </summary>
  /// <returns>The counter map associated with the default graph.</returns>
  public static Dictionary<string, int> get_default_graph_uid_map()
  {
      var current_graph = ops.get_default_graph();

      if (keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(current_graph))
          return keras.backend.PER_GRAPH_LAYER_NAME_UIDS[current_graph];

      // First request for this graph: register an empty map and hand it back.
      var fresh_map = new Dictionary<string, int>();
      keras.backend.PER_GRAPH_LAYER_NAME_UIDS[current_graph] = fresh_map;
      return fresh_map;
  }
  /// <summary>
  /// Tells whether any tensor in <paramref name="inputs"/> has no Keras history attached.
  /// </summary>
  /// <param name="inputs">Tensors to inspect.</param>
  /// <returns>
  /// <c>true</c> when at least one tensor's <c>KerasHistory</c> is null; otherwise <c>false</c>.
  /// </returns>
  public static bool needs_keras_history(Tensors inputs)
      => inputs.Any(tensor => tensor.KerasHistory == null);
  /// <summary>
  /// Runs <c>CreateKerasHistoryHelper</c> over <paramref name="inputs"/> with empty bookkeeping
  /// lists and reports the layers it created.
  /// </summary>
  /// <param name="inputs">Tensors handed to the helper.</param>
  /// <returns>The layers collected by the helper, in the order they were added.</returns>
  public static Layer[] create_keras_history(Tensors inputs)
  {
      var new_layers = new List<Layer>();
      var seen_ops = new List<Operation>();

      CreateKerasHistoryHelper(inputs, seen_ops, new_layers);

      return new_layers.ToArray();
  }
  /// <summary>
  /// For each tensor without Keras history, wraps the operation that produced it in an op layer,
  /// after first recursing into those of the operation's inputs that use Keras history.
  /// </summary>
  /// <param name="tensors">Tensors to process; tensors whose <c>KerasHistory</c> is set are skipped.</param>
  /// <param name="processed_ops">Operations already wrapped; every newly wrapped operation is appended.</param>
  /// <param name="created_layers">Receives each op layer created by this call and by its recursive calls.</param>
  public static void CreateKerasHistoryHelper(Tensors tensors, List<Operation> processed_ops, List<Layer> created_layers)
  {
      foreach (var current in tensors)
      {
          if (current.KerasHistory != null)
              continue;

          var producer = current.op;
          if (processed_ops.Contains(producer))
              continue;

          // Split the producer's inputs: inputs that use Keras history become layer inputs,
          // the rest are evaluated and stored by their input position.
          var tracked_inputs = new List<Tensor>();
          var constant_inputs = new Dictionary<int, NDArray>();
          foreach (var (index, candidate) in enumerate(producer.inputs._inputs))
          {
              if (!uses_keras_history(candidate))
              {
                  tf_with(ops.init_scope(), delegate
                  {
                      constant_inputs[index] = keras.backend.eval_in_eager_or_function(candidate);
                  });
                  continue;
              }

              tracked_inputs.Add(candidate);
          }

          // Upstream operations are wrapped before this one.
          CreateKerasHistoryHelper(tracked_inputs, processed_ops, created_layers);

          var layer_args = new TensorFlowOpLayerArgs
          {
              NodeDef = producer.node_def,
              Constants = constant_inputs,
              Name = producer.name
          };
          var wrapper = GetLayer<ITensorFlowOpLayer>(layer_args);
          created_layers.Add(wrapper);
          wrapper.SetConnectivityMetadata(tracked_inputs, producer.outputs);
          processed_ops.Add(producer);
      }
  }
  /// <summary>
  /// Creates a layer implementing <typeparamref name="T"/> from the "TensorFlow.Keras.Layers" assembly.
  /// </summary>
  /// <typeparam name="T">Type the layer implementation must be assignable to.</typeparam>
  /// <param name="args">Passed as the single constructor argument of the selected layer type.</param>
  /// <returns>A new instance of the last matching concrete type reported by the assembly.</returns>
  /// <exception cref="NotImplementedException">
  /// No concrete type in the assembly is assignable to <typeparamref name="T"/>.
  /// </exception>
  static Layer GetLayer<T>(LayerArgs args)
  {
      var assembly = Assembly.Load("TensorFlow.Keras.Layers");

      // Choose the implementation first and construct it exactly once. Taking the last match
      // keeps the same winner as a loop that overwrites its result on every match, without
      // running the constructors of the candidates that would be thrown away.
      var implementation = assembly.GetTypes()
          .LastOrDefault(type => type.IsClass && !type.IsAbstract && typeof(T).IsAssignableFrom(type));

      // Report the type that was searched for; this also avoids dereferencing a null args.
      if (implementation == null)
          throw new NotImplementedException($"Can't find implementation for type {typeof(T).Name}");

      return (Layer)Activator.CreateInstance(implementation, new object[] { args });
  }
  /// <summary>
  /// Checks whether <paramref name="op_input"/>, or any tensor reachable through the inputs of
  /// the operations producing it, has Keras history attached.
  /// </summary>
  /// <param name="op_input">Tensor whose upstream tensors are searched.</param>
  /// <returns>
  /// <c>true</c> as soon as a tensor with a non-null <c>KerasHistory</c> is found; otherwise <c>false</c>.
  /// </returns>
  static bool uses_keras_history(Tensor op_input)
  {
      // An explicit work list with a visited set examines a tensor that feeds several
      // operations once instead of once per path, and does not grow the call stack.
      var visited = new HashSet<Tensor>(TensorReferenceComparer.Instance);
      var pending = new Stack<Tensor>();
      pending.Push(op_input);

      while (pending.Count > 0)
      {
          var tensor = pending.Pop();
          if (!visited.Add(tensor))
              continue;

          if (tensor.KerasHistory != null)
              return true;

          foreach (var input in tensor.op.inputs._inputs)
              pending.Push(input);
      }

      return false;
  }

  /// <summary>
  /// Compares tensors by object identity, so the visited set does not depend on how
  /// Tensor defines equality.
  /// </summary>
  sealed class TensorReferenceComparer : IEqualityComparer<Tensor>
  {
      public static readonly TensorReferenceComparer Instance = new TensorReferenceComparer();

      public bool Equals(Tensor x, Tensor y) => ReferenceEquals(x, y);

      public int GetHashCode(Tensor obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj);
  }
  167. }
  168. }