You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Tape.cs 2.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Runtime.InteropServices;
  5. using System.Text;
  6. using Tensorflow.Eager;
  7. namespace Tensorflow.Gradients
  8. {
  /// <summary>
  /// Managed wrapper around a native tape handle obtained from
  /// <c>c_api.TFE_TapeSetNew</c>. Every operation is forwarded to the
  /// corresponding <c>c_api</c> entry point.
  /// </summary>
  public class Tape : DisposableObject
  {
    /// <summary>
    /// Nesting identifier. This class only stores the value; nothing in this
    /// type reads or assigns it.
    /// </summary>
    public int nesting_id { get; set; }

    /// <summary>
    /// Creates a native tape via <c>c_api.TFE_TapeSetNew</c> and keeps the
    /// returned handle in <c>_handle</c>.
    /// </summary>
    /// <param name="persistent">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
    /// <param name="watch_accessed_variables">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
    public Tape(bool persistent, bool watch_accessed_variables)
    {
      _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
    }

    /// <summary>
    /// Registers <paramref name="x"/> with this tape through
    /// <c>c_api.TFE_TapeWatch</c>.
    /// </summary>
    /// <param name="x">Tensor whose <c>EagerTensorHandle</c> is passed to the native call.</param>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
    public void watch(EagerTensor x)
    {
      // Fail with a named argument instead of a bare NullReferenceException
      // when reading x.EagerTensorHandle.
      if (x is null)
      {
        throw new ArgumentNullException(nameof(x));
      }

      c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle);
    }

    /// <summary>
    /// Passes <paramref name="tape"/> to <c>c_api.TFE_TapeSetRemove</c>.
    /// Note that the tape removed is the argument, not this instance.
    /// </summary>
    /// <param name="tape">Tape to remove.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tape"/> is null.</exception>
    public void pop_tape(Tape tape)
    {
      // The implicit IntPtr conversion below reads tape._handle, so a null
      // argument would otherwise surface as a NullReferenceException from
      // inside the conversion operator.
      if (tape is null)
      {
        throw new ArgumentNullException(nameof(tape));
      }

      c_api.TFE_TapeSetRemove(tape);
    }

    /// <summary>
    /// Forwards <paramref name="variable"/> to
    /// <c>c_api.TFE_TapeVariableAccessed</c>.
    /// </summary>
    /// <param name="variable">Variable reported as accessed.</param>
    public static void variable_accessed(ResourceVariable variable)
    {
      c_api.TFE_TapeVariableAccessed(variable);
    }

    /// <summary>
    /// Queries <c>c_api.TFE_TapeWatchedVariables</c> for this tape and wraps
    /// each returned entry in a new <see cref="ResourceVariable"/>.
    /// </summary>
    /// <returns>
    /// One <see cref="ResourceVariable"/> per element of the native result's
    /// <c>Data</c>, in the same order.
    /// </returns>
    public unsafe ResourceVariable[] watched_variables()
    {
      BindingArray result = c_api.TFE_TapeWatchedVariables(_handle);

      // Each entry is paired with the value returned by
      // c_api.ResourceVariable_Handle for that same entry.
      return result.Data
        .Select(entry => new ResourceVariable(entry, c_api.ResourceVariable_Handle(entry)))
        .ToArray();
    }

    /// <summary>
    /// Tells whether <paramref name="dtype"/> is one of the data types this
    /// class treats as trainable.
    /// </summary>
    /// <param name="dtype">Data type to classify.</param>
    /// <returns>
    /// <c>true</c> for half, bfloat16, float, double, complex64, complex128,
    /// resource and variant; <c>false</c> for every other value.
    /// </returns>
    public static bool IsDtypeTrainable(DataType dtype)
    {
      switch (dtype)
      {
        case DataType.DtHalf:
        case DataType.DtBfloat16:
        case DataType.DtFloat:
        case DataType.DtDouble:
        case DataType.DtComplex64:
        case DataType.DtComplex128:
        case DataType.DtResource:
        case DataType.DtVariant:
          return true;
        default:
          return false;
      }
    }

    /// <summary>
    /// Disposal hook; performs no work.
    /// </summary>
    /// <param name="handle">Handle offered for release; currently ignored.</param>
    protected override void DisposeUnmanagedResources(IntPtr handle)
    {
      // NOTE(review): the handle created by TFE_TapeSetNew is not released
      // here. Confirm whether the native side owns its lifetime or whether a
      // matching delete call is missing.
    }

    /// <summary>
    /// Exposes the underlying native handle so a <see cref="Tape"/> can be
    /// passed where an <see cref="IntPtr"/> is expected.
    /// </summary>
    /// <param name="tape">Tape whose handle is returned; must not be null.</param>
    public static implicit operator IntPtr(Tape tape)
      => tape._handle;
  }
  61. }