You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

Tape.ComputeGradient.cs 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Util;
  6. using static Tensorflow.tensorflow;
  7. namespace Tensorflow.Gradients
  8. {
  9. public partial class Tape
  10. {
        // Once more than this many unaggregated gradients are pending for a single
        // tensor, ComputeGradient checks whether they should be force-aggregated.
        int kMinAggregateCount = 4;
        // Pending-gradient byte budget (128 MiB). ComputeGradient estimates bytes as
        // count * element_count * 4 — assumes 4-byte elements; TODO confirm for
        // non-float32 dtypes.
        int kMinAggregateBytes = 128 * 1024 * 1024;
        /// <summary>
        /// Walks the recorded tape backwards from the target tensors to the source
        /// tensors, calling each op's backward function and accumulating gradients
        /// per tensor id. Returns one gradient per entry of
        /// <paramref name="source_tensor_ids"/>, in order; an entry is null when no
        /// gradient flows to that source.
        /// </summary>
        /// <param name="target_tensor_ids">Ids of the tensors being differentiated.</param>
        /// <param name="source_tensor_ids">Ids of the tensors whose gradients are requested.</param>
        /// <param name="sources_that_are_targets">Sources that are themselves targets; seeded with ones-like gradients.</param>
        /// <param name="output_gradients">Optional incoming gradients for the targets; when empty (or an entry is null) a ones-like tensor is used instead.</param>
        /// <exception cref="RuntimeError">A backward function returned too few gradients, or the tape state is inconsistent after traversal.</exception>
        public Tensor[] ComputeGradient(long[] target_tensor_ids,
            long[] source_tensor_ids,
            UnorderedMap<long, TapeTensor> sources_that_are_targets,
            Tensor[] output_gradients)
        {
            var result = new List<Tensor>(source_tensor_ids.Length);
            var sources_set = new UnorderedSet<long>(source_tensor_ids);
            // element-count cache per tensor id, used for the aggregation-size estimate below
            var gradients_size = new UnorderedMap<long, long>();
            // Prune the tape to the ops that actually contribute to the targets.
            var state = PrepareBackprop(
                target_tensor_ids, tensor_tape_, op_tape_, sources_set, persistent_);
            // Ops with no missing output gradients are ready to process immediately.
            var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
            // tensor id -> list of unaggregated gradient contributions
            var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
                output_gradients,
                tensor_tape_,
                state.op_tape);
            while (!op_stack.empty())
            {
                var op = op_stack.Dequeue();
                if (!state.op_tape.find(op, out var trace))
                    continue;
                // Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}");
                // Each op is processed exactly once; remove it so the final
                // "Invalid tape state" check below can detect unprocessed ops.
                state.op_tape.erase(op);
                var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
                var unneeded_gradients = new List<long>();
                for (int i = 0; i < trace.input_tensor_id.Length; i++)
                {
                    var in_tensor_id = trace.input_tensor_id[i];
                    // Inputs that are neither watched on the tape nor requested
                    // sources do not need a gradient from the backward function.
                    if (!tensor_tape_.find(in_tensor_id) &&
                        !sources_set.find(in_tensor_id))
                        unneeded_gradients.Add(i);
                }
                bool any_gradient_nonzero = false;
                // Output slots with no incoming gradient that must be filled with
                // zeros before calling the backward function.
                var zero_indices = new List<int>();
                for (int i = 0; i < trace.output_tensor_info.Length; ++i)
                {
                    var id = trace.output_tensor_info[i].GetID();
                    if (!gradients.find(id, out var grad_it))
                    {
                        // Some ops (e.g. SoftmaxCrossEntropyWithLogits) accept null
                        // for specific output gradients; for those, skip the zeros.
                        if (FunctionsAcceptingNoneForIndicesMap().find(trace.op_type, out var func_name_it) &&
                            func_name_it.find(i))
                        {
                            out_gradients.Add(null);
                        }
                        else
                        {
                            out_gradients.Add(null);
                            zero_indices.Add(i);
                        }
                    }
                    else
                    {
                        any_gradient_nonzero = true;
                        // Sum all pending contributions for this output.
                        var new_gradients = grad_it.Count == 1 ?
                            grad_it[0] :
                            gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients
                        if (!sources_set.find(id))
                            gradients.Remove(id);
                        else
                        {
                            // Requested source: keep only the aggregated value.
                            grad_it.Clear();
                            grad_it.Add(new_gradients);
                            // vspace.MarkAsResult(new_gradients);
                        }
                        out_gradients.Add(new_gradients);
                    }
                }
                Tensor[] in_gradients;
                if (any_gradient_nonzero)
                {
                    // Materialize zeros only now that we know the op has at least
                    // one nonzero output gradient.
                    foreach (var i in zero_indices)
                        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
                    in_gradients = CallBackwardFunction(trace.backward_function,
                        unneeded_gradients,
                        out_gradients);
                    if (in_gradients.Count() != trace.input_tensor_id.Count())
                        throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
                    if (!persistent_)
                    {
                        // trace.backward_function_deleter(trace.backward_function);
                    }
                }
                else
                {
                    // All output gradients are zero, so all input gradients are too
                    // (represented as nulls).
                    in_gradients = new Tensor[trace.input_tensor_id.Length];
                }
                // Route each produced input gradient to its tensor, and release
                // bookkeeping for tensors whose usages are exhausted.
                for (int i = 0; i < in_gradients.Length; ++i)
                {
                    var id = trace.input_tensor_id[i];
                    if (in_gradients[i] != null)
                    {
                        var unaggregated_grads = gradients[id];
                        unaggregated_grads.Add(in_gradients[i]);
                        if (unaggregated_grads.Count > kMinAggregateCount)
                        {
                            if (!gradients_size.find(id, out var size))
                            {
                                size = (long)unaggregated_grads[0].size;
                                gradients_size.emplace(id, size);
                            }
                            // Estimated bytes = count * elements * 4; early
                            // aggregation above this budget is not implemented yet.
                            if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
                            {
                                throw new NotImplementedException("");
                            }
                        }
                    }
                    if (!state.tensor_usage_counts.find(id))
                        continue;
                    state.tensor_usage_counts[id]--;
                    if (state.tensor_usage_counts[id] > 0)
                        continue;
                    // No remaining usages of this tensor.
                    if (!tensor_tape_.find(id, out var tape_it))
                    {
                        // Not produced by a recorded op: drop any stored gradients
                        // (unless still needed — sources keep theirs via the map).
                        if (gradients.find(id, out var grad_it))
                        {
                            // foreach (var g in grad_it)
                            //     DeleteGradient(g);
                            gradients.erase(id);
                        }
                        continue;
                    }
                    var op_id = tape_it;
                    // -1 marks a watched tensor with no producing op.
                    if (op_id == -1)
                        continue;
                    // The producing op is one tensor closer to having all of its
                    // output gradients; enqueue it once none are missing.
                    if(state.op_missing_tensor.find(op_id, out var missing_it))
                    {
                        state.op_missing_tensor[op_id]--;
                        if (state.op_missing_tensor[op_id] == 0)
                            op_stack.Enqueue(op_id);
                    }
                }
            }
            if (state.op_tape.Count > 0)
                throw new RuntimeError("Invalid tape state.");
            // Collect per-source results, aggregating any leftover contributions.
            var used_gradient_ids = new List<long>(source_tensor_ids.Length);
            foreach (var id in source_tensor_ids)
            {
                if (!gradients.find(id, out var grad_it))
                    result.Add(null);
                else
                {
                    if(grad_it.Count > 1)
                    {
                        var grad = gen_math_ops.add_n(grad_it.ToArray());
                        grad_it.Clear();
                        grad_it.Add(grad);
                    }
                    result.Add(grad_it[0]);
                    used_gradient_ids.Add(id);
                }
            }
            /*foreach(var grad_pair in gradients)
            {
                if(!used_gradient_ids.Contains(grad_pair.Key))
                {
                    foreach(var g in grad_pair.Value)
                    {
                        vspace.DeleteGradient(g);
                    }
                }
            }*/
            return result.ToArray();
        }
  175. UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap()
  176. {
  177. var m = new UnorderedMap<string, UnorderedSet<int>>();
  178. m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
  179. m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
  180. m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
  181. return m;
  182. }
        /// <summary>
        /// Seeds the gradient map for backprop: each target tensor id is given its
        /// initial gradient list, taken from <paramref name="output_gradients"/> when
        /// the caller supplied one, or a ones-like tensor otherwise.
        /// </summary>
        /// <param name="target_tensor_ids">Ids of the tensors being differentiated.</param>
        /// <param name="sources_that_are_targets">Sources that are also targets; used to shape default ones when the target has no producing op.</param>
        /// <param name="output_gradients">Caller-supplied upstream gradients, parallel to <paramref name="target_tensor_ids"/>; may be empty.</param>
        /// <param name="tensor_tape">Maps tensor id to the id of the op that produced it (-1 for watched leaves).</param>
        /// <param name="op_tape">Maps op id to its recorded trace entry.</param>
        /// <exception cref="RuntimeError">A target's producing op is missing from the op tape.</exception>
        /// <exception cref="ValueError">The producing op has no output matching the target's id.</exception>
        UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
            UnorderedMap<long, TapeTensor> sources_that_are_targets,
            Tensor[] output_gradients,
            TensorTape tensor_tape,
            OpTape<BackwardFunction, TapeTensor> op_tape)
        {
            var result = new UnorderedMapEnumerable<long, List<Tensor>>();
            for (int i = 0; i < target_tensor_ids.Length; ++i)
            {
                var id = target_tensor_ids[i];
                // No explicit upstream gradient for this target: default to ones.
                if (output_gradients.Length == 0 || output_gradients[i] == null)
                {
                    if (tensor_tape.find(id, out var tensor_id) && tensor_id != -1)
                    {
                        // Produced by a recorded op: locate the matching output slot
                        // so OnesLike gets the right shape/dtype.
                        if (!op_tape.find(tensor_tape[id], out var op_it))
                            throw new RuntimeError("Internal state of the gradient tape is invalid: " +
                                "failed to find operation producing a tensor");
                        bool found = false;
                        for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
                        {
                            if (op_it.output_tensor_info[j].GetID() == id)
                            {
                                found = true;
                                var ones = op_it.output_tensor_info[j].OnesLike();
                                result[id].Add(ones);
                                break;
                            }
                        }
                        if (!found)
                        {
                            throw new ValueError("Internal state of the gradient tape is invalid: " +
                                "none of operations outputs match expected tensor");
                        }
                    }
                    else
                    {
                        // Not produced by a recorded op; only seed a gradient when
                        // the target is itself one of the requested sources.
                        if (sources_that_are_targets.find(id, out var source_tensor))
                            result[id].Add(source_tensor.OnesLike());
                    }
                }
                else
                {
                    // Caller supplied the upstream gradient for this target.
                    result[id].Add(output_gradients[i]);
                }
            }
            return result;
        }
  230. Queue<long> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape,
  231. UnorderedMap<long, long> op_missing_tensor)
  232. {
  233. var result = new Queue<long>();
  234. foreach(var op_entry in op_tape)
  235. {
  236. if (!op_missing_tensor.find(op_entry.Key))
  237. result.Enqueue(op_entry.Key);
  238. }
  239. return result;
  240. }
  241. }
  242. }