|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- using System;
- using System.Collections.Generic;
- using System.Text;
-
namespace Tensorflow.Gradients
{
    /// <summary>
    /// Managed wrapper around a tape handle created by
    /// <c>c_api.TFE_TapeSetNew</c>.
    /// </summary>
    /// <remarks>
    /// The handle is kept in <c>_handle</c>, which is not declared in this
    /// class (presumably inherited from <see cref="DisposableObject"/> —
    /// TODO confirm).
    /// </remarks>
    public class Tape : DisposableObject
    {
        /// <summary>
        /// Gradient tape associated with this object. It is not assigned
        /// anywhere inside this class.
        /// </summary>
        public GradientTape tape { get; set; }

        /// <summary>
        /// Nesting identifier. It is not assigned anywhere inside this class.
        /// </summary>
        public int nesting_id { get; set; }

        /// <summary>
        /// Creates the native tape and stores the returned handle.
        /// </summary>
        /// <param name="persistent">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
        /// <param name="watch_accessed_variables">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
        public Tape(bool persistent, bool watch_accessed_variables)
        {
            var created = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
            _handle = created;
        }

        /// <summary>
        /// Registers a tensor with the native tape by calling
        /// <c>TFE_TapeWatch</c> with this tape's handle, the tensor and the
        /// tensor's <c>Id</c>.
        /// </summary>
        /// <param name="x">Tensor to watch.</param>
        public void watch(Tensor x)
            => c_api.TFE_TapeWatch(_handle, x, x.Id);

        /// <summary>
        /// Reports whether <paramref name="dtype"/> is one of the data types
        /// this class treats as trainable: half, bfloat16, float, double,
        /// complex64, complex128, resource or variant.
        /// </summary>
        /// <param name="dtype">Data type to classify.</param>
        /// <returns>
        /// <c>true</c> for the types listed above; <c>false</c> for every
        /// other value.
        /// </returns>
        public static bool IsDtypeTrainable(DataType dtype)
        {
            // Real-valued floating point types.
            bool isFloating = dtype == DataType.DtHalf
                || dtype == DataType.DtBfloat16
                || dtype == DataType.DtFloat
                || dtype == DataType.DtDouble;

            // Complex floating point types.
            bool isComplex = dtype == DataType.DtComplex64
                || dtype == DataType.DtComplex128;

            // Resource and variant types are accepted as well.
            bool isOpaque = dtype == DataType.DtResource
                || dtype == DataType.DtVariant;

            return isFloating || isComplex || isOpaque;
        }

        /// <summary>
        /// Disposal hook for the unmanaged handle. Intentionally a no-op.
        /// </summary>
        /// <param name="handle">Handle passed in by the caller; unused.</param>
        protected override void DisposeUnmanagedResources(IntPtr handle)
        {
            // NOTE(review): nothing is released here, so the handle obtained
            // from TFE_TapeSetNew in the constructor is not freed by this
            // class — confirm whether the native side needs an explicit
            // release call.
        }

        /// <summary>
        /// Lets a <see cref="Tape"/> be passed wherever a raw
        /// <see cref="IntPtr"/> handle is expected.
        /// </summary>
        /// <param name="tape">Tape whose handle is returned.</param>
        public static implicit operator IntPtr(Tape tape)
        {
            return tape._handle;
        }
    }
}
|