You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseResourceVariable.cs 8.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. using Tensorflow.NumPy;
  2. using System;
  3. using Tensorflow.Eager;
  4. using Tensorflow.Variables;
  5. using static Tensorflow.Binding;
  6. namespace Tensorflow
  7. {
  /// <summary>
  /// Common state and read/assign helpers for a variable that is addressed
  /// through a resource <see cref="Tensor"/> handle.
  /// </summary>
  public class BaseResourceVariable : DisposableObject
  {
      protected string _name;

      /// <summary>Name reported for this variable; by default the stored handle name.</summary>
      public virtual string Name => _handle_name;

      protected TF_DataType _dtype;

      /// <summary>Data type used when reading from and converting values for this variable.</summary>
      public TF_DataType dtype => _dtype;

      protected string _handle_name;
      protected string handle_name => _handle_name;

      protected string _unique_id;
      public string UniqueId => _unique_id;

      protected bool _in_graph_mode;

      protected bool _trainable;

      /// <summary>When true, reads of this variable are reported to the active tapes.</summary>
      public bool trainable => _trainable;

      protected Tensor _initial_value;

      /// <summary>Same value as <see cref="Initializer"/>; both expose <c>initializer_op</c>.</summary>
      public Operation initializer => initializer_op;

      protected Tensor _parent_op;
      public Tensor parent_op => _parent_op;

      /// <summary>
      /// Tensor handle
      /// </summary>
      protected Tensor handle;
      public Tensor Handle => handle;

      protected Tensor _graph_element;

      /// <summary>When non-null, <see cref="value"/> returns this tensor instead of issuing a read op.</summary>
      public Tensor GraphElement => _graph_element;

      protected Shape _shape;
      public Shape shape => _shape;

      protected Operation initializer_op;
      public Operation Initializer => initializer_op;

      // The following three properties simply forward to the handle tensor.
      public Operation Op => handle.op;
      public Graph Graph => handle.graph;
      public string Device => handle.Device;

      // Kept alive only by this object; see the comment in __init__.
      EagerResourceDeleter eager_resource_deleter;

      public BaseResourceVariable()
      {
      }

      /// <summary>
      /// Stores the supplied metadata and handle on this instance and records the
      /// native handle pointer in <c>_handle</c>.
      /// </summary>
      /// <param name="trainable">Value exposed through <see cref="trainable"/>.</param>
      /// <param name="handle">Resource handle tensor. Must not be null.</param>
      /// <param name="name">Stored in <c>_name</c>.</param>
      /// <param name="unique_id">Value exposed through <see cref="UniqueId"/>.</param>
      /// <param name="handle_name">Stored with a <c>":0"</c> suffix appended.</param>
      /// <exception cref="ArgumentNullException"><paramref name="handle"/> is null.</exception>
      public void __init__(bool trainable = true,
          Tensor handle = null,
          string name = null,
          string unique_id = null,
          string handle_name = null)
      {
          // The handle is dereferenced unconditionally below, so reject null with a
          // descriptive exception rather than an anonymous NullReferenceException.
          // ReferenceEquals keeps the check independent of any == overload on Tensor.
          if (object.ReferenceEquals(handle, null))
              throw new ArgumentNullException(nameof(handle));

          _trainable = trainable;
          _handle_name = handle_name + ":0";
          _unique_id = unique_id;
          this.handle = handle;
          _name = name;

          // After the handle has been created, set up a way to clean it up when
          // executing eagerly. We'll hold the only reference to the deleter, so that
          // when this object is garbage collected the deleter will be too. This
          // means ResourceVariables can be part of reference cycles without those
          // cycles being uncollectable.
          if (!handle.IsCreatedInGraphMode)
          {
              _handle = handle.EagerTensorHandle.DangerousGetHandle();
              eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
          }
          else
          {
              _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle();
          }

#if TRACK_TENSOR_LIFE
          print($"Created Resource 0x{_handle.ToString("x16")} {_name}");
#endif
      }

      /// <summary>
      /// Assigns <paramref name="value"/> to this variable.
      /// </summary>
      /// <remarks>
      /// Two paths exist. When the runtime type of <paramref name="value"/> is exactly
      /// <see cref="Tensor"/>, the assignment goes through <c>gen_state_ops.assign</c>
      /// and <paramref name="use_locking"/> is forwarded. Otherwise the value is
      /// converted with <c>ops.convert_to_tensor</c> using this variable's dtype and
      /// assigned through <c>assign_variable_op</c>; <paramref name="use_locking"/> is
      /// not used on that path.
      /// </remarks>
      /// <param name="value">New value.</param>
      /// <param name="use_locking">Forwarded only on the exact-<see cref="Tensor"/> path.</param>
      /// <param name="name">Name forwarded to the assign op.</param>
      /// <param name="read_value">
      /// When true the result is a tensor holding the variable's value; when false the
      /// assign op itself (possibly null) is returned.
      /// </param>
      public Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
      {
          // NOTE(review): this is an exact type match, so subclasses of Tensor take the
          // conversion path below. Presumably intentional - confirm before relaxing it
          // to an `is Tensor` test, since that would change which op is emitted.
          if (value.GetType() == typeof(Tensor))
          {
              var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name);
              if (read_value)
                  return assign;
              return assign.op;
          }

          var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
          var assign_op = gen_resource_variable_ops.assign_variable_op(
              handle, value_tensor, name: name);

          if (read_value)
              return gen_resource_variable_ops.read_variable_op(handle, dtype);

          // Guard before the op is converted to the Tensor return type.
          if (assign_op == null)
              return null;
          return assign_op;
      }

      /// <summary>
      /// Assigns <paramref name="value"/> to the slice described by the packed
      /// begin/end/strides tensors of <paramref name="slice"/>.
      /// </summary>
      public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
      {
          _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
      }

      /// <summary>
      /// Forwards to <c>gen_array_ops.resource_strided_slice_assign</c> on this
      /// variable's handle. <paramref name="name"/> is currently not forwarded.
      /// </summary>
      void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
          int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
      {
          // The result of the call was never used, so it is no longer captured.
          gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
              begin_mask: begin_mask,
              end_mask: end_mask,
              ellipsis_mask: ellipsis_mask,
              new_axis_mask: new_axis_mask,
              shrink_axis_mask: shrink_axis_mask);
      }

      /// <summary>
      /// Assigns <paramref name="value"/> and returns an unread-variable wrapper
      /// (see <c>_lazy_read</c>) instead of reading the value back.
      /// </summary>
      public IVariableV1 assign_lazy_load(Tensor value, string name = null)
      {
          var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
          var assign_op = gen_resource_variable_ops.assign_variable_op(
              handle, value_tensor, name: name);
          var variable = _lazy_read(assign_op, value_tensor);
          return variable;
      }

      /// <summary>
      /// Returns <see cref="GraphElement"/> when it is set, otherwise a freshly
      /// issued read of the variable.
      /// </summary>
      public Tensor value()
          => GraphElement ?? _read_variable_op();

      /// <summary>
      /// Reports the access to the tapes and reads the variable through
      /// <c>read_variable_op</c>.
      /// </summary>
      protected Tensor _read_variable_op()
      {
          variable_accessed(this);
          var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
          // _maybe_set_handle_data(_dtype, _handle, result);

          // have to set shape when converting to substituent placeholder
          if (result.shape.ndim == -1)
          {
              // The read came back with ndim == -1: stamp this variable's own shape
              // onto the result in its graph, then check the shared status.
              c_api.TF_GraphSetTensorShape(result.graph,
                  result._as_tf_output(),
                  shape.dims,
                  shape.ndim,
                  tf.Status.Handle);
              tf.Status.Check(true);
          }

          return result;
      }

      /// <summary>
      /// Reports the access to the tapes and wraps this variable's handle in a new
      /// <c>_UnreadVariable</c>. <paramref name="op"/> and <paramref name="value"/>
      /// are currently not used.
      /// </summary>
      IVariableV1 _lazy_read(Operation op, Tensor value)
      {
          variable_accessed(this);
          return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
      }

      /// <summary>
      /// Records that `variable` was accessed for the tape and FuncGraph.
      /// </summary>
      void variable_accessed(BaseResourceVariable variable)
      {
          if (variable.trainable)
          {
              // NOTE(review): the `as` cast yields null when `variable` is not a
              // ResourceVariable, and that null is handed to the tape. Confirm whether
              // such instances can be trainable before changing this.
              foreach (var tape in tf.GetTapeSet())
                  tape.VariableAccessed(variable as ResourceVariable);
          }
      }

      /// <summary>
      /// Constructs an op which reads the value of this variable.
      ///
      /// Should be used when there are multiple reads, or when it is desirable to
      /// read the value only after some condition is true.
      /// </summary>
      /// <returns>The identity of the value read inside a "Read" name scope.</returns>
      protected Tensor read_value()
      {
          var value = tf_with(ops.name_scope("Read"), delegate
          {
              return _read_variable_op();
          });
          return array_ops.identity(value);
      }

      /// <summary>
      /// Adds <paramref name="delta"/> to this variable.
      /// </summary>
      /// <param name="delta">Amount to add; converted to a tensor of this variable's dtype.</param>
      /// <param name="use_locking">Currently not used by this method.</param>
      /// <param name="name">Name forwarded to the assign-add op.</param>
      /// <param name="read_value">
      /// When true a fresh read of the variable is returned; when false the
      /// assign-add op itself is returned, or null if no op was produced.
      /// </param>
      public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
      {
          var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
              ops.convert_to_tensor(delta, dtype: dtype), name: name);

          if (read_value)
              return gen_resource_variable_ops.read_variable_op(handle, dtype);

          // Same guard as in assign(): do not convert a null op to the Tensor
          // return type. A lazy read via _lazy_read is a possible alternative here.
          if (object.ReferenceEquals(assign_add_op, null))
              return null;
          return assign_add_op;
      }

      /// <summary>
      /// Subtracts <paramref name="delta"/> from this variable.
      /// </summary>
      /// <param name="delta">Amount to subtract; converted to a tensor of this variable's dtype.</param>
      /// <param name="use_locking">Currently not used by this method.</param>
      /// <param name="name">Name forwarded to the assign-sub op.</param>
      /// <param name="read_value">
      /// When true a fresh read of the variable is returned; when false the
      /// assign-sub op itself is returned, or null if no op was produced.
      /// </param>
      public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
      {
          var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
              ops.convert_to_tensor(delta, dtype: dtype), name: name);

          if (read_value)
              return gen_resource_variable_ops.read_variable_op(handle, dtype);

          // Same guard as in assign(): do not convert a null op to the Tensor
          // return type. A lazy read via _lazy_read is a possible alternative here.
          if (object.ReferenceEquals(assign_sub_op, null))
              return null;
          return assign_sub_op;
      }

      /// <summary>
      /// Subtracts <paramref name="delta"/> and returns an unread-variable wrapper
      /// (see <c>_lazy_read</c>) instead of reading the value back.
      /// </summary>
      public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
      {
          var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
              ops.convert_to_tensor(delta, dtype: dtype), name: name);

          return _lazy_read(assign_sub_op, delta);
      }

      /// <summary>
      /// Describes the variable by name, shape and dtype; when executing eagerly the
      /// current value (obtained through <c>read_value</c>) is appended.
      /// </summary>
      public override string ToString()
      {
          if (tf.Context.executing_eagerly())
              return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={read_value()}";
          else
              return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
      }

      /// <summary>Reads the variable and returns the result of calling <c>numpy()</c> on it.</summary>
      public NDArray numpy() => read_value().numpy();

      /// <summary>
      /// Performs no release work here; it only emits a trace line when
      /// <c>TRACK_TENSOR_LIFE</c> is defined.
      /// </summary>
      protected override void DisposeUnmanagedResources(IntPtr handle)
      {
#if TRACK_TENSOR_LIFE
          print($"Deleted Resource 0x{handle.ToString("x16")} {_name}");
#endif
      }

      /// <summary>
      /// Returns this variable as a tensor.
      /// </summary>
      /// <param name="dtype">Currently not used by this method.</param>
      /// <param name="name">Currently not used by this method.</param>
      /// <param name="as_ref">
      /// When true, returns the first input of the op that produced
      /// <c>read_value()</c>; otherwise returns <see cref="value"/>.
      /// </param>
      public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
      {
          if (as_ref)
              return read_value().op.inputs[0];
          else
              return value();
      }
  }
  208. }