You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseResourceVariable.cs 8.8 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. using NumSharp;
  2. using System;
  3. using Tensorflow.Eager;
  4. using Tensorflow.Variables;
  5. using static Tensorflow.Binding;
  6. namespace Tensorflow
  7. {
  /// <summary>
  /// Common state and operations for variables that are addressed through a
  /// resource handle tensor: initialization, assignment, reads and tape
  /// bookkeeping.
  /// </summary>
  public class BaseResourceVariable : DisposableObject
  {
      /// <summary>
      /// Name passed to <c>__init__</c>. Within this class it is only used by the
      /// TRACK_TENSOR_LIFE trace output.
      /// </summary>
      protected string _name;

      /// <summary>Returns <c>_handle_name</c> (not <c>_name</c>).</summary>
      public virtual string Name => _handle_name;

      protected TF_DataType _dtype;

      /// <summary>Data type used when converting assigned values and when reading the variable.</summary>
      public TF_DataType dtype => _dtype;

      /// <summary>Handle name; <c>__init__</c> stores it with a ":0" suffix appended.</summary>
      protected string _handle_name;
      protected string handle_name => _handle_name;

      protected string _unique_id;
      public string UniqueId => _unique_id;

      protected bool _in_graph_mode;

      protected bool _trainable;

      /// <summary>When true, reads of this variable are reported to every tape in the tape set.</summary>
      public bool trainable => _trainable;

      protected Tensor _initial_value;

      /// <summary>Same value as <see cref="Initializer"/>; both expose <c>initializer_op</c>.</summary>
      public Operation initializer => initializer_op;

      protected Tensor _parent_op;
      public Tensor parent_op => _parent_op;

      /// <summary>
      /// Tensor handle
      /// </summary>
      protected Tensor handle;
      public Tensor Handle => handle;

      /// <summary>When non-null, <see cref="value"/> returns this tensor instead of issuing a read op.</summary>
      protected Tensor _graph_element;
      public Tensor GraphElement => _graph_element;

      protected TensorShape _shape;
      public TensorShape shape => _shape;

      protected Operation initializer_op;
      public Operation Initializer => initializer_op;

      // The three accessors below forward to the handle tensor, so they require
      // __init__ (or a subclass) to have set `handle` first.
      public Operation Op => handle.op;
      public Graph Graph => handle.graph;
      public string Device => handle.Device;

      // Only assigned for eager handles; see the comment in __init__.
      EagerResourceDeleter eager_resource_deleter;

      public BaseResourceVariable()
      {
      }

      /// <summary>
      /// Stores the variable's identity and handle, and wires up the native
      /// handle held by the <c>DisposableObject</c> base class.
      /// </summary>
      /// <param name="trainable">Whether reads should be reported to gradient tapes.</param>
      /// <param name="handle">Resource handle tensor. Must not be null.</param>
      /// <param name="name">Variable name (only used for trace output in this class).</param>
      /// <param name="unique_id">Identifier exposed through <see cref="UniqueId"/>.</param>
      /// <param name="handle_name">Handle name; ":0" is appended before it is stored.</param>
      /// <exception cref="ArgumentNullException"><paramref name="handle"/> is null.</exception>
      public void __init__(bool trainable = true,
          Tensor handle = null,
          string name = null,
          string unique_id = null,
          string handle_name = null)
      {
          // The handle is dereferenced below; reject null up front so the object
          // is not left half-initialized. ReferenceEquals is used so that no
          // user-defined operator or conversion on Tensor can interfere.
          if (ReferenceEquals(handle, null))
              throw new ArgumentNullException(nameof(handle));

          _trainable = trainable;
          _handle_name = handle_name + ":0";
          _unique_id = unique_id;
          this.handle = handle;
          _name = name;

          // After the handle has been created, set up a way to clean it up when
          // executing eagerly. We'll hold the only reference to the deleter, so that
          // when this object is garbage collected the deleter will be too. This
          // means ResourceVariables can be part of reference cycles without those
          // cycles being uncollectable.
          if (handle.IsEagerTensor)
          {
              _handle = handle.EagerTensorHandle.DangerousGetHandle();
              eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
          }
          else
          {
              _handle = handle;
          }

  #if TRACK_TENSOR_LIFE
          print($"Created Resource 0x{_handle.ToString("x16")} {_name}");
  #endif
      }

      /// <summary>
      /// Assigns <paramref name="value"/> to this variable.
      /// </summary>
      /// <remarks>
      /// When the runtime type of <paramref name="value"/> is exactly <c>Tensor</c>
      /// the assignment goes through <c>gen_state_ops.assign</c>; any other value
      /// (including subclasses of <c>Tensor</c>) is converted to a tensor of this
      /// variable's dtype and assigned with <c>assign_variable_op</c>.
      /// </remarks>
      /// <param name="value">Value to assign. Must not be null (its runtime type is inspected).</param>
      /// <param name="use_locking">Forwarded only on the <c>gen_state_ops.assign</c> path.</param>
      /// <param name="name">Optional name for the assign op.</param>
      /// <param name="read_value">
      /// When true the assigned value (or a fresh read op) is returned; otherwise
      /// the assign op itself is returned.
      /// </param>
      public Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
      {
          if (value.GetType() == typeof(Tensor))
          {
              var assigned = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name);
              if (read_value)
                  return assigned;
              return assigned.op;
          }

          var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
          var assign_op = gen_resource_variable_ops.assign_variable_op(
              handle, value_tensor, name: name);
          if (read_value)
              return gen_resource_variable_ops.read_variable_op(handle, dtype);
          return assign_op;
      }

      /// <summary>
      /// Assigns <paramref name="value"/> into the slice described by <paramref name="slice"/>,
      /// using its packed begin/end/strides tensors and default (zero) masks.
      /// </summary>
      public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
      {
          _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
      }

      /// <summary>
      /// Issues <c>resource_strided_slice_assign</c> on this variable's handle.
      /// The op's result is not needed here. <paramref name="name"/> is currently
      /// not forwarded to the op.
      /// </summary>
      void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
          int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
      {
          gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
              begin_mask: begin_mask,
              end_mask: end_mask,
              ellipsis_mask: ellipsis_mask,
              new_axis_mask: new_axis_mask,
              shrink_axis_mask: shrink_axis_mask);
      }

      /// <summary>
      /// Assigns <paramref name="value"/> (converted to this variable's dtype) and
      /// returns an <c>_UnreadVariable</c> wrapper instead of reading the value back.
      /// </summary>
      public IVariableV1 assign_lazy_load(Tensor value, string name = null)
      {
          var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
          var assign_op = gen_resource_variable_ops.assign_variable_op(
              handle, value_tensor, name: name);
          return _lazy_read(assign_op, value_tensor);
      }

      /// <summary>
      /// Returns <see cref="GraphElement"/> when it is set, otherwise a new read op.
      /// </summary>
      public Tensor value()
          => GraphElement ?? _read_variable_op();

      /// <summary>
      /// Records the access on the tapes, then reads the variable through
      /// <c>read_variable_op</c>.
      /// </summary>
      protected Tensor _read_variable_op()
      {
          variable_accessed(this);
          var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
          // _maybe_set_handle_data(_dtype, _handle, result);

          // have to set shape when converting to substituent placeholder
          if (result.TensorShape.ndim == -1)
          {
              // The read result has unknown rank: stamp this variable's shape
              // onto it in the graph, then check the shared status object.
              c_api.TF_GraphSetTensorShape(result.graph,
                  result._as_tf_output(),
                  shape.as_list_long(),
                  shape.ndim,
                  tf.Status.Handle);
              tf.Status.Check(true);
          }

          return result;
      }

      /// <summary>
      /// Records the access and wraps this variable's handle in an
      /// <c>_UnreadVariable</c>. <paramref name="op"/> and <paramref name="value"/>
      /// are currently unused.
      /// </summary>
      IVariableV1 _lazy_read(Operation op, Tensor value)
      {
          variable_accessed(this);
          return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
      }

      /// <summary>
      /// Records that `variable` was accessed for the tape and FuncGraph.
      /// </summary>
      void variable_accessed(BaseResourceVariable variable)
      {
          if (!variable.trainable)
              return;

          // The cast does not depend on the tape, so do it once.
          // NOTE(review): if `variable` is not a ResourceVariable this is null and
          // null is handed to every tape, exactly as before - confirm the tape
          // tolerates that or that only ResourceVariables are ever trainable.
          var watched = variable as ResourceVariable;
          foreach (var tape in tf.GetTapeSet())
              tape.VariableAccessed(watched);
      }

      /// <summary>
      /// Constructs an op which reads the value of this variable.
      ///
      /// Should be used when there are multiple reads, or when it is desirable to
      /// read the value only after some condition is true.
      /// </summary>
      /// <returns>
      /// The result of the read op, created under the "Read" name scope and
      /// passed through <c>array_ops.identity</c>.
      /// </returns>
      protected Tensor read_value()
      {
          var value = tf_with(ops.name_scope("Read"), delegate
          {
              return _read_variable_op();
          });
          return array_ops.identity(value);
      }

      /// <summary>
      /// Adds <paramref name="delta"/> (converted to this variable's dtype) to the variable.
      /// </summary>
      /// <param name="delta">Amount to add.</param>
      /// <param name="use_locking">Currently unused.</param>
      /// <param name="name">Optional name for the assign-add op.</param>
      /// <param name="read_value">
      /// When true a fresh read op is returned; otherwise the assign-add op is returned.
      /// </param>
      public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
      {
          var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
              ops.convert_to_tensor(delta, dtype: dtype), name: name);
          if (read_value)
              return gen_resource_variable_ops.read_variable_op(handle, dtype);
          // return _lazy_read(assign_add_op);
          return assign_add_op;
      }

      /// <summary>
      /// Subtracts <paramref name="delta"/> (converted to this variable's dtype) from the variable.
      /// </summary>
      /// <param name="delta">Amount to subtract.</param>
      /// <param name="use_locking">Currently unused.</param>
      /// <param name="name">Optional name for the assign-sub op.</param>
      /// <param name="read_value">
      /// When true a fresh read op is returned; otherwise the assign-sub op is returned.
      /// </param>
      public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
      {
          var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
              ops.convert_to_tensor(delta, dtype: dtype), name: name);
          if (read_value)
              return gen_resource_variable_ops.read_variable_op(handle, dtype);
          // return _lazy_read(assign_sub_op);
          return assign_sub_op;
      }

      /// <summary>
      /// Subtracts <paramref name="delta"/> and returns an <c>_UnreadVariable</c>
      /// wrapper instead of reading the value back.
      /// </summary>
      public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
      {
          var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
              ops.convert_to_tensor(delta, dtype: dtype), name: name);
          return _lazy_read(assign_sub_op, delta);
      }

      /// <summary>
      /// Describes the variable by name, shape and dtype; when executing eagerly
      /// the current value is read and appended as <c>numpy=...</c>.
      /// </summary>
      public override string ToString()
      {
          var eager = tf.Context.executing_eagerly();
          var description = $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
          return eager
              ? $"{description}, numpy={tensor_util.to_numpy_string(read_value())}"
              : description;
      }

      /// <summary>Reads the variable and returns the result's <c>numpy()</c> value.</summary>
      public NDArray numpy() => read_value().numpy();

      /// <summary>
      /// Dispose hook from <c>DisposableObject</c>. Nothing is released here; it
      /// only emits a trace line when TRACK_TENSOR_LIFE is defined.
      /// </summary>
      protected override void DisposeUnmanagedResources(IntPtr handle)
      {
  #if TRACK_TENSOR_LIFE
          print($"Deleted Resource 0x{handle.ToString("x16")} {_name}");
  #endif
      }

      /// <summary>
      /// Returns this variable as a tensor.
      /// </summary>
      /// <param name="dtype">Currently unused.</param>
      /// <param name="name">Currently unused.</param>
      /// <param name="as_ref">
      /// When true, returns the first input of the op produced by <c>read_value()</c>;
      /// otherwise returns <see cref="value"/>.
      /// </param>
      public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
      {
          if (as_ref)
              return read_value().op.inputs[0];

          return value();
      }
  }
  206. }