You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Tape.PrepareBackprop.cs 2.1 kB

4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. using System.Collections.Generic;
  2. using Tensorflow.Util;
  3. namespace Tensorflow.Gradients
  4. {
  public partial class Tape
  {
      /// <summary>
      /// Walks backwards from <paramref name="target"/> through the recorded tapes and
      /// gathers the state a backward pass starts from: the op entries that were reached,
      /// how many times each input tensor id is consumed by those entries, and a per-op
      /// counter in <c>op_missing_tensor</c>.
      /// </summary>
      /// <param name="target">Tensor ids the traversal is seeded with.</param>
      /// <param name="tensor_tape">
      /// Lookup from a tensor id to an op id. Tensor ids without an entry, and entries
      /// whose op id is <c>-1</c>, are not followed.
      /// </param>
      /// <param name="op_tape">
      /// Lookup from an op id to its tape entry. When <paramref name="persistent_tape"/>
      /// is <c>false</c>, every visited entry is erased and the whole tape is cleared
      /// before returning, so the caller's tape is mutated.
      /// </param>
      /// <param name="sources_set">Not read by this method.</param>
      /// <param name="persistent_tape">
      /// <c>true</c> to leave <paramref name="op_tape"/> untouched; <c>false</c> to consume it.
      /// </param>
      /// <returns>The populated <see cref="BackpropInitialState"/>.</returns>
      public BackpropInitialState PrepareBackprop(long[] target,
          TensorTape tensor_tape,
          OpTape op_tape,
          UnorderedSet<long> sources_set,
          bool persistent_tape)
      {
          // Depth-first work list of tensor ids still to be examined.
          var pending = new Stack<long>();
          for (int i = 0; i < target.Length; i++)
          {
              pending.Push(target[i]);
          }

          var state = new BackpropInitialState();

          while (pending.Count != 0)
          {
              long current_tensor = pending.Pop();

              // Tensor ids that the tensor tape does not know about end the walk here.
              if (!tensor_tape.TryGetValue(current_tensor, out var current_op))
              {
                  continue;
              }

              // -1 is treated as "no op to follow".
              if (current_op == -1)
              {
                  continue;
              }

              // The op must still be on the op tape ...
              if (!op_tape.TryGetValue(current_op, out var entry))
              {
                  continue;
              }

              // ... and must not have been collected already.
              if (state.op_tape.find(current_op))
              {
                  continue;
              }

              state.op_tape.emplace(current_op, entry);

              foreach (var input_id in entry.input_tensor_id)
              {
                  if (!state.tensor_usage_counts.find(input_id))
                  {
                      // First sighting: start the count and, if the tensor tape has an
                      // entry for this input, queue it for traversal.
                      state.tensor_usage_counts[input_id] = 1;
                      if (tensor_tape.find(input_id))
                      {
                          pending.Push(input_id);
                      }

                      continue;
                  }

                  state.tensor_usage_counts[input_id]++;
              }

              if (!persistent_tape)
              {
                  op_tape.erase(current_op);
              }
          }

          // For every counted tensor id that the tensor tape maps to a real op id
          // (anything but -1), bump that op's op_missing_tensor counter by one.
          foreach (var usage in state.tensor_usage_counts)
          {
              if (!tensor_tape.TryGetValue(usage.Key, out var mapped_op) || mapped_op == -1)
              {
                  continue;
              }

              state.op_missing_tensor[mapped_op]++;
          }

          // A non-persistent tape is emptied once the needed entries live in the result.
          if (!persistent_tape)
          {
              op_tape.Clear();
          }

          return state;
      }
  }
  66. }