using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow.Gradients
{
    /// <summary>
    /// Managed wrapper around a native tape handle obtained from
    /// <c>c_api.TFE_TapeSetNew</c>. Every operation on this type is a thin
    /// forwarder to the corresponding <c>c_api</c> entry point.
    /// </summary>
    public class Tape : DisposableObject
    {
        /// <summary>
        /// Identifier stored alongside the tape. It is not read or written by
        /// any member of this class; callers own its meaning.
        /// </summary>
        public int nesting_id { get; set; }

        /// <summary>
        /// Creates the native tape and stores the returned handle.
        /// </summary>
        /// <param name="persistent">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
        /// <param name="watch_accessed_variables">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
        public Tape(bool persistent, bool watch_accessed_variables)
        {
            _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
        }

        /// <summary>
        /// Registers the eager handle of <paramref name="x"/> with this tape
        /// through <c>TFE_TapeWatch</c>.
        /// </summary>
        /// <param name="x">Tensor whose <c>EagerTensorHandle</c> is passed to the native call.</param>
        public void watch(EagerTensor x)
            => c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle);

        /// <summary>
        /// Forwards <paramref name="tape"/> to <c>TFE_TapeSetRemove</c>.
        /// </summary>
        /// <remarks>
        /// The call acts on the supplied argument, not on this instance.
        /// NOTE(review): confirm that ignoring <c>this</c> is intentional.
        /// </remarks>
        /// <param name="tape">Tape handed to the native removal call.</param>
        public void pop_tape(Tape tape)
            => c_api.TFE_TapeSetRemove(tape);

        /// <summary>
        /// Forwards <paramref name="variable"/> to <c>TFE_TapeVariableAccessed</c>.
        /// </summary>
        /// <param name="variable">Variable reported to the native call.</param>
        public static void variable_accessed(ResourceVariable variable)
            => c_api.TFE_TapeVariableAccessed(variable);

        /// <summary>
        /// Queries <c>TFE_TapeWatchedVariables</c> for this tape and wraps each
        /// returned element in a new <see cref="ResourceVariable"/>.
        /// </summary>
        /// <returns>
        /// One <see cref="ResourceVariable"/> per element of the native result's
        /// <c>Data</c>, in the same order.
        /// </returns>
        public unsafe ResourceVariable[] watched_variables()
        {
            BindingArray watched = c_api.TFE_TapeWatchedVariables(_handle);

            // Each element is paired with the value that
            // c_api.ResourceVariable_Handle yields for it.
            return watched.Data
                .Select(item => new ResourceVariable(item, c_api.ResourceVariable_Handle(item)))
                .ToArray();
        }

        /// <summary>
        /// Tells whether <paramref name="dtype"/> is one of the data types this
        /// class treats as trainable.
        /// </summary>
        /// <param name="dtype">Data type to classify.</param>
        /// <returns>
        /// <c>true</c> for half, bfloat16, float, double, complex64, complex128,
        /// resource and variant; <c>false</c> for every other value.
        /// </returns>
        public static bool IsDtypeTrainable(DataType dtype)
        {
            switch (dtype)
            {
                case DataType.DtHalf:
                case DataType.DtBfloat16:
                case DataType.DtFloat:
                case DataType.DtDouble:
                case DataType.DtComplex64:
                case DataType.DtComplex128:
                case DataType.DtResource:
                case DataType.DtVariant:
                    return true;
            }

            return false;
        }

        /// <summary>
        /// Intentionally performs no work.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the handle created in the constructor is not released
        /// here; confirm whether the native side owns its lifetime.
        /// </remarks>
        /// <param name="handle">Unused.</param>
        protected override void DisposeUnmanagedResources(IntPtr handle)
        {
        }

        /// <summary>
        /// Exposes the underlying handle so a <see cref="Tape"/> can be passed
        /// wherever an <see cref="IntPtr"/> is expected.
        /// </summary>
        /// <param name="tape">Tape whose handle is returned.</param>
        public static implicit operator IntPtr(Tape tape)
        {
            return tape._handle;
        }
    }
}