You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Tape.cs 1.3 kB

5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. namespace Tensorflow.Gradients
  5. {
  6. public class Tape : DisposableObject
  7. {
  8. public GradientTape tape { get; set; }
  9. public int nesting_id { get; set; }
  10. public Tape(bool persistent, bool watch_accessed_variables)
  11. {
  12. _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
  13. }
  14. public void watch(Tensor x)
  15. {
  16. c_api.TFE_TapeWatch(_handle, x, x.Id);
  17. }
  18. public static bool IsDtypeTrainable(DataType dtype)
  19. {
  20. switch (dtype)
  21. {
  22. case DataType.DtHalf:
  23. case DataType.DtBfloat16:
  24. case DataType.DtFloat:
  25. case DataType.DtDouble:
  26. case DataType.DtComplex64:
  27. case DataType.DtComplex128:
  28. case DataType.DtResource:
  29. case DataType.DtVariant:
  30. return true;
  31. default:
  32. return false;
  33. }
  34. }
  35. protected override void DisposeUnmanagedResources(IntPtr handle)
  36. {
  37. }
  38. public static implicit operator IntPtr(Tape tape)
  39. => tape._handle;
  40. }
  41. }