You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Tape.cs 2.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. using Tensorflow.Eager;
  namespace Tensorflow.Gradients
  {
      /// <summary>
      /// Managed wrapper around a native gradient tape whose handle is created
      /// through <c>c_api.TFE_TapeSetNew</c>.
      /// </summary>
      public class Tape : DisposableObject
      {
          /// <summary>
          /// Nesting identifier for this tape. This class only stores the value and
          /// never reads it itself.
          /// NOTE(review): presumably assigned by whoever pushes nested tapes — confirm.
          /// </summary>
          public int nesting_id { get; set; }

          /// <summary>
          /// Creates a native tape via <c>c_api.TFE_TapeSetNew</c> and stores the
          /// returned handle.
          /// </summary>
          /// <param name="persistent">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
          /// <param name="watch_accessed_variables">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
          public Tape(bool persistent, bool watch_accessed_variables)
          {
              _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
          }

          /// <summary>
          /// Registers an eager tensor with this tape via <c>c_api.TFE_TapeWatch</c>.
          /// </summary>
          /// <param name="x">Tensor to watch; its <c>EagerTensorHandle</c> is passed to native code.</param>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
          public void watch(EagerTensor x)
          {
              // Report the bad argument by name instead of an unattributed
              // NullReferenceException when reading x.EagerTensorHandle.
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }

              c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle);
          }

          /// <summary>
          /// Removes <paramref name="tape"/> from the native tape set via
          /// <c>c_api.TFE_TapeSetRemove</c>.
          /// </summary>
          /// <remarks>
          /// Acts on the argument, not on the instance the method is invoked on.
          /// </remarks>
          /// <param name="tape">Tape to remove.</param>
          /// <exception cref="ArgumentNullException"><paramref name="tape"/> is null.</exception>
          public void pop_tape(Tape tape)
          {
              // Reject null before it is converted/forwarded to the native layer.
              if (tape is null)
              {
                  throw new ArgumentNullException(nameof(tape));
              }

              c_api.TFE_TapeSetRemove(tape);
          }

          /// <summary>
          /// Reports an access of <paramref name="variable"/> to the native layer via
          /// <c>c_api.TFE_TapeVariableAccessed</c>.
          /// </summary>
          /// <param name="variable">The variable that was accessed.</param>
          /// <exception cref="ArgumentNullException"><paramref name="variable"/> is null.</exception>
          public static void variable_accessed(ResourceVariable variable)
          {
              // Do not hand a null variable to native code.
              if (variable is null)
              {
                  throw new ArgumentNullException(nameof(variable));
              }

              c_api.TFE_TapeVariableAccessed(variable);
          }

          /// <summary>
          /// Returns the variables this tape reports as watched, as reported by
          /// <c>c_api.TFE_TapeWatchedVariables</c>.
          /// </summary>
          /// <returns>
          /// One <see cref="ResourceVariable"/> per element of the native array, each built
          /// from the element handle and the tensor returned by <c>c_api.ResourceVariable_Handle</c>.
          /// </returns>
          public unsafe ResourceVariable[] watched_variables()
          {
              BindingArray result = c_api.TFE_TapeWatchedVariables(_handle);
              var variables = new ResourceVariable[result.length];

              // The native buffer is read as a contiguous run of pointer-sized handles.
              // NOTE(review): the buffer is not released here — confirm who owns it.
              var items = (IntPtr*)result.array;
              for (int i = 0; i < variables.Length; i++)
              {
                  var handle = items[i];
                  var tensor = c_api.ResourceVariable_Handle(handle);
                  variables[i] = new ResourceVariable(handle, tensor);
              }

              return variables;
          }

          /// <summary>
          /// Reports whether tensors of <paramref name="dtype"/> can carry gradients.
          /// </summary>
          /// <param name="dtype">Data type to test.</param>
          /// <returns>
          /// <c>true</c> for half, bfloat16, float, double, complex64, complex128,
          /// resource and variant; <c>false</c> for every other type.
          /// </returns>
          public static bool IsDtypeTrainable(DataType dtype)
          {
              switch (dtype)
              {
                  case DataType.DtHalf:
                  case DataType.DtBfloat16:
                  case DataType.DtFloat:
                  case DataType.DtDouble:
                  case DataType.DtComplex64:
                  case DataType.DtComplex128:
                  case DataType.DtResource:
                  case DataType.DtVariant:
                      return true;
                  default:
                      return false;
              }
          }

          /// <summary>
          /// Intentionally releases nothing.
          /// </summary>
          /// <remarks>
          /// NOTE(review): the handle obtained from <c>TFE_TapeSetNew</c> is not freed here.
          /// Presumably the native tape set owns it and releases it on
          /// <c>TFE_TapeSetRemove</c> — confirm, otherwise this leaks one native tape per instance.
          /// </remarks>
          protected override void DisposeUnmanagedResources(IntPtr handle)
          {
          }

          /// <summary>
          /// Exposes the native tape handle so a <see cref="Tape"/> can be passed where an
          /// <see cref="IntPtr"/> is expected.
          /// </summary>
          /// <remarks>
          /// Dereferences <paramref name="tape"/>; a null tape throws <see cref="NullReferenceException"/>.
          /// </remarks>
          public static implicit operator IntPtr(Tape tape)
              => tape._handle;
      }
  }