You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseResourceVariable.cs 12 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. using Tensorflow.NumPy;
  2. using System;
  3. using Tensorflow.Eager;
  4. using Tensorflow.Variables;
  5. using Tensorflow.Train;
  6. using static Tensorflow.Binding;
  7. using System.Collections.Generic;
  8. using System.Diagnostics;
  9. using Tensorflow.Checkpoint;
  10. using Tensorflow.Training.Saving.SavedModel;
  11. namespace Tensorflow
  12. {
  13. public class BaseResourceVariable : DisposableTrackableObject
  14. {
  // ----- Backing state ------------------------------------------------

  /// <summary>Value of the <c>name</c> argument given to <c>__init__</c>.</summary>
  protected string _name;

  /// <summary>Data type of the variable; not assigned in this class (presumably set by subclasses).</summary>
  protected TF_DataType _dtype;

  /// <summary>Handle name; <c>__init__</c> stores the supplied handle name with ":0" appended.</summary>
  protected string _handle_name;

  /// <summary>Value of the <c>unique_id</c> argument given to <c>__init__</c>.</summary>
  protected string _unique_id;

  /// <summary>Graph-mode flag; not assigned in this class (presumably set by subclasses).</summary>
  protected bool _in_graph_mode;

  /// <summary>Value of the <c>trainable</c> argument given to <c>__init__</c>.</summary>
  protected bool _trainable;

  /// <summary>Initial value tensor; not assigned in this class.</summary>
  protected Tensor _initial_value;

  /// <summary>Parent op tensor; not assigned in this class.</summary>
  protected Tensor _parent_op;

  /// <summary>
  /// Tensor handle
  /// </summary>
  protected Tensor handle;

  /// <summary>Graph element returned by <c>value()</c> when it is not null; not assigned in this class.</summary>
  protected Tensor _graph_element;

  /// <summary>Declared shape of the variable; not assigned in this class.</summary>
  protected Shape _shape;

  /// <summary>Initializer operation; not assigned in this class.</summary>
  protected Operation initializer_op;

  // Created by __init__ when the handle is an EagerTensor; this field keeps the deleter referenced.
  EagerResourceDeleter eager_resource_deleter;

  // ----- Accessors ----------------------------------------------------

  /// <summary>Returns the handle name (<c>_handle_name</c>).</summary>
  public virtual string Name => _handle_name;

  /// <summary>Returns the name supplied to <c>__init__</c> (<c>_name</c>).</summary>
  public virtual string SharedName => _name;

  /// <summary>Returns <c>_dtype</c>.</summary>
  public TF_DataType dtype => _dtype;

  /// <summary>Gets or sets <c>_handle_name</c> directly (no ":0" suffix is added by the setter).</summary>
  public string handle_name
  {
      get => _handle_name;
      set => _handle_name = value;
  }

  /// <summary>Returns <c>_unique_id</c>.</summary>
  public string UniqueId => _unique_id;

  /// <summary>Returns <c>_in_graph_mode</c>.</summary>
  internal bool InGraphMode => _in_graph_mode;

  /// <summary>Returns <c>_trainable</c>.</summary>
  public bool Trainable => _trainable;

  /// <summary>Returns <c>initializer_op</c> (same value as <c>Initializer</c>).</summary>
  public Operation initializer => initializer_op;

  /// <summary>Returns <c>_parent_op</c>.</summary>
  public Tensor parent_op => _parent_op;

  /// <summary>Returns the tensor handle field.</summary>
  public Tensor Handle => handle;

  /// <summary>Returns <c>_graph_element</c>.</summary>
  public Tensor GraphElement => _graph_element;

  /// <summary>Returns <c>_shape</c>.</summary>
  public Shape shape => _shape;

  /// <summary>Returns <c>initializer_op</c> (same value as <c>initializer</c>).</summary>
  public Operation Initializer => initializer_op;

  /// <summary>Operation of the tensor handle (<c>handle.op</c>).</summary>
  public Operation Op => handle.op;

  /// <summary>Graph of the tensor handle (<c>handle.graph</c>).</summary>
  public Graph Graph => handle.graph;

  /// <summary>Device of the tensor handle (<c>handle.Device</c>).</summary>
  public string Device => handle.Device;

  /// <summary>Aggregation mode; defaults to <c>VariableAggregation.None</c>.</summary>
  public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None;

  /// <summary>
  /// Creates an empty instance; the state is filled in later (see <c>__init__</c>).
  /// </summary>
  public BaseResourceVariable()
  {
  }
  /// <summary>
  /// Stores the basic attributes of the variable and records the native
  /// pointer behind <paramref name="handle"/> in <c>_handle</c>.
  /// </summary>
  /// <param name="trainable">Stored in <c>_trainable</c>.</param>
  /// <param name="handle">Tensor handle of the variable; may be null.</param>
  /// <param name="name">Stored in <c>_name</c>.</param>
  /// <param name="unique_id">Stored in <c>_unique_id</c>.</param>
  /// <param name="handle_name">Stored in <c>_handle_name</c> with ":0" appended.</param>
  public void __init__(bool trainable = true,
      Tensor handle = null,
      string name = null,
      string unique_id = null,
      string handle_name = null)
  {
      _name = name;
      _unique_id = unique_id;
      _trainable = trainable;
      _handle_name = handle_name + ":0";
      this.handle = handle;

      if (handle is null)
      {
          // TODO: fix this dangerous change.
          _handle = IntPtr.Zero;
      }
      else if (handle is EagerTensor)
      {
          // After the handle has been created, set up a way to clean it up when
          // executing eagerly. This object holds the only reference to the
          // deleter, so the deleter is collected together with this object and
          // ResourceVariables can take part in reference cycles without
          // making those cycles uncollectable.
          _handle = handle.EagerTensorHandle.DangerousGetHandle();
          eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
      }
      else if (handle.Handle == null)
      {
          // Non-eager tensor without a native handle.
          _handle = IntPtr.Zero;
      }
      else
      {
          _handle = handle.Handle.DangerousGetHandle();
      }

  #if TRACK_TENSOR_LIFE
      print($"Created Resource 0x{_handle.ToString("x16")} {_name}");
  #endif
  }
  /// <summary>
  /// Assigns <paramref name="value"/> to this variable.
  /// </summary>
  /// <remarks>
  /// A value whose runtime type is exactly <c>Tensor</c> is passed to
  /// <c>gen_state_ops.assign</c>; anything else is converted with
  /// <c>ops.convert_to_tensor</c> and passed to
  /// <c>gen_resource_variable_ops.assign_variable_op</c>.
  /// </remarks>
  /// <param name="value">New value.</param>
  /// <param name="use_locking">Only forwarded on the <c>gen_state_ops.assign</c> path.</param>
  /// <param name="name">Optional name forwarded to the assign op.</param>
  /// <param name="read_value">When true the result is the (re-)read variable value; otherwise the assign op.</param>
  /// <returns>See <paramref name="read_value"/>; may be null when the assign op is null.</returns>
  public Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
  {
      bool is_plain_tensor = value.GetType() == typeof(Tensor);
      if (!is_plain_tensor)
      {
          var converted = ops.convert_to_tensor(value, dtype: dtype);
          var variable_op = gen_resource_variable_ops.assign_variable_op(
              handle, converted, name: name);

          if (read_value)
              return gen_resource_variable_ops.read_variable_op(handle, dtype);

          // Checked explicitly before the result is converted to the Tensor return type.
          if (variable_op == null)
              return null;
          return variable_op;
      }

      var assigned = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name);
      if (read_value)
          return assigned;
      return assigned.op;
  }
  /// <summary>
  /// Assigns <paramref name="value"/> to the slice described by the packed
  /// begin/end/strides tensors of <paramref name="slice"/>.
  /// </summary>
  /// <param name="value">Tensor written into the slice.</param>
  /// <param name="slice">Parsed slice arguments supplying <c>PackedBegin</c>, <c>PackedEnd</c> and <c>PackedStrides</c>.</param>
  public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
      => _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
  /// <summary>
  /// Forwards the slice description and the value to
  /// <c>gen_array_ops.resource_strided_slice_assign</c> for this variable's handle.
  /// The result of that call is not used, and <paramref name="name"/> is not forwarded.
  /// </summary>
  void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
      int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
  {
      gen_array_ops.resource_strided_slice_assign(
          handle, begin, end, strides, value,
          begin_mask: begin_mask,
          end_mask: end_mask,
          ellipsis_mask: ellipsis_mask,
          new_axis_mask: new_axis_mask,
          shrink_axis_mask: shrink_axis_mask);
  }
  /// <summary>
  /// Assigns <paramref name="value"/> through
  /// <c>gen_resource_variable_ops.assign_variable_op</c> and returns the
  /// variable produced by <c>_lazy_read</c> instead of reading the value back.
  /// </summary>
  /// <param name="value">New value; converted to this variable's dtype first.</param>
  /// <param name="name">Optional name forwarded to the assign op.</param>
  public IVariableV1 assign_lazy_load(Tensor value, string name = null)
  {
      var converted = ops.convert_to_tensor(value, dtype: dtype);
      return _lazy_read(
          gen_resource_variable_ops.assign_variable_op(handle, converted, name: name),
          converted);
  }
  /// <summary>
  /// Returns <c>GraphElement</c> when it is set; otherwise reads the
  /// variable through <c>_read_variable_op</c>.
  /// </summary>
  public Tensor value()
  {
      var element = GraphElement;
      if (element is null)
          return _read_variable_op();
      return element;
  }
  /// <summary>
  /// Records the access via <c>variable_accessed</c> and reads the variable
  /// with <c>gen_resource_variable_ops.read_variable_op</c>.
  /// </summary>
  /// <returns>The tensor produced by the read op.</returns>
  protected Tensor _read_variable_op()
  {
      variable_accessed(this);
      var tensor = gen_resource_variable_ops.read_variable_op(handle, _dtype);
      // _maybe_set_handle_data(_dtype, _handle, result);

      if (tensor.shape.ndim != -1)
          return tensor;

      // The result has no known rank (ndim == -1): write this variable's
      // shape onto the graph tensor. Needed when converting to a
      // substituent placeholder.
      c_api.TF_GraphSetTensorShape(
          tensor.graph,
          tensor._as_tf_output(),
          shape.dims,
          shape.ndim,
          tf.Status);
      tf.Status.Check(true);
      return tensor;
  }
  /// <summary>
  /// Records the access and wraps this variable's handle, dtype, shape,
  /// graph-mode flag and unique id in a new <c>_UnreadVariable</c>.
  /// Neither <paramref name="op"/> nor <paramref name="value"/> is used here.
  /// </summary>
  IVariableV1 _lazy_read(Operation op, Tensor value)
  {
      variable_accessed(this);
      var unread = new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
      return unread;
  }
  /// <summary>
  /// Records that `variable` was accessed on every tape returned by
  /// <c>tf.GetTapeSet()</c>.
  /// </summary>
  /// <remarks>
  /// Nothing is recorded for non-trainable variables. The tapes are handed a
  /// <c>ResourceVariable</c>, so a variable of any other runtime type is
  /// skipped instead of passing a null reference to each tape.
  /// </remarks>
  /// <param name="variable">The variable that was read or written.</param>
  void variable_accessed(BaseResourceVariable variable)
  {
      if (!variable.Trainable)
          return;

      // Test the type once, outside the loop; an `as` cast here would yield
      // null for other subclasses and forward that null to every tape.
      if (!(variable is ResourceVariable resource_variable))
          return;

      foreach (var tape in tf.GetTapeSet())
          tape.VariableAccessed(resource_variable);
  }
  /// <summary>
  /// Builds an op that reads the current value of this variable.
  ///
  /// Intended for cases with several reads, or where the value should only
  /// be read once some condition holds.
  /// </summary>
  /// <returns>
  /// The result of <c>array_ops.identity</c> applied to the tensor read
  /// inside the "Read" name scope.
  /// </returns>
  protected Tensor read_value()
  {
      Tensor raw = tf_with(ops.name_scope("Read"), delegate
      {
          return _read_variable_op();
      });

      return array_ops.identity(raw);
  }
  /// <summary>
  /// Adds <paramref name="delta"/> to this variable via
  /// <c>gen_resource_variable_ops.assign_add_variable_op</c>.
  /// </summary>
  /// <param name="delta">Amount to add; converted to this variable's dtype.</param>
  /// <param name="use_locking">Not used by this implementation.</param>
  /// <param name="name">Optional name forwarded to the op.</param>
  /// <param name="read_value">When true the variable is read back and that value is returned; otherwise the update op is returned.</param>
  public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
  {
      var delta_tensor = ops.convert_to_tensor(delta, dtype: dtype);
      var update_op = gen_resource_variable_ops.assign_add_variable_op(Handle, delta_tensor, name: name);

      if (!read_value)
          return update_op; // a lazy read (_lazy_read) is not used here

      return gen_resource_variable_ops.read_variable_op(handle, dtype);
  }
  /// <summary>
  /// Subtracts <paramref name="delta"/> from this variable via
  /// <c>gen_resource_variable_ops.assign_sub_variable_op</c>.
  /// </summary>
  /// <param name="delta">Amount to subtract; converted to this variable's dtype.</param>
  /// <param name="use_locking">Not used by this implementation.</param>
  /// <param name="name">Optional name forwarded to the op.</param>
  /// <param name="read_value">When true the variable is read back and that value is returned; otherwise the update op is returned.</param>
  public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
  {
      var delta_tensor = ops.convert_to_tensor(delta, dtype: dtype);
      var update_op = gen_resource_variable_ops.assign_sub_variable_op(Handle, delta_tensor, name: name);

      if (!read_value)
          return update_op; // a lazy read (_lazy_read) is not used here

      return gen_resource_variable_ops.read_variable_op(handle, dtype);
  }
  /// <summary>
  /// Subtracts <paramref name="delta"/> from this variable and returns the
  /// variable produced by <c>_lazy_read</c> instead of reading the value back.
  /// </summary>
  /// <param name="delta">Amount to subtract; converted to this variable's dtype for the op.</param>
  /// <param name="name">Optional name forwarded to the op.</param>
  public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
  {
      var delta_tensor = ops.convert_to_tensor(delta, dtype: dtype);
      var update_op = gen_resource_variable_ops.assign_sub_variable_op(Handle, delta_tensor, name: name);

      // The unconverted delta is what gets handed to _lazy_read.
      return _lazy_read(update_op, delta);
  }
  /// <summary>
  /// Describes the variable by name, shape and dtype; when executing eagerly
  /// the current value is appended as well.
  /// </summary>
  public override string ToString()
  {
      var description = $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";

      if (tf.Context.executing_eagerly())
          description += $", numpy={read_value().numpy()}";

      return description;
  }
  /// <summary>
  /// Reads the variable with <c>read_value</c> and returns the result of
  /// calling <c>numpy()</c> on it.
  /// </summary>
  public NDArray numpy()
  {
      var current = read_value();
      return current.numpy();
  }
  /// <summary>
  /// Disposal hook. This override releases nothing itself; it only prints a
  /// trace line when the <c>TRACK_TENSOR_LIFE</c> symbol is defined.
  /// </summary>
  /// <param name="handle">Native pointer being disposed (used only for the trace output).</param>
  protected override void DisposeUnmanagedResources(IntPtr handle)
  {
  #if TRACK_TENSOR_LIFE
      print($"Deleted Resource 0x{handle.ToString("x16")} {_name}");
  #endif
  }
  /// <summary>
  /// Returns this variable as a tensor.
  /// </summary>
  /// <param name="dtype">Not used by this implementation.</param>
  /// <param name="name">Not used by this implementation.</param>
  /// <param name="as_ref">
  /// When true, returns the first input of the op behind <c>read_value()</c>;
  /// otherwise returns <c>value()</c>.
  /// </param>
  public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
  {
      if (!as_ref)
          return value();

      return read_value().op.inputs[0];
  }
  /// <summary>
  /// Creates an uninitialized copy of this variable with
  /// <c>resource_variable_ops.copy_to_graph_uninitialized</c> and returns the
  /// object map (this -> copy) and the resource map (this handle -> copy handle).
  /// </summary>
  /// <param name="save_options">
  /// When its variable policy asks to save variable devices, the copy is
  /// created under <c>ops.device(this.Device)</c>.
  /// </param>
  public override (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(SaveOptions save_options)
  {
      BaseResourceVariable copy;
      if (!save_options.experimental_variable_policy.save_variable_devices())
      {
          copy = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
      }
      else
      {
          Debug.Assert(this is ResourceVariable);
          copy = tf_with(ops.device(this.Device), _ =>
          {
              return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
          });
      }

      var obj_map = new Dictionary<Trackable, Trackable> { [this] = copy };
      var resource_map = new Dictionary<Tensor, Tensor> { [this.handle] = copy.handle };
      return (obj_map, resource_map);
  }
  /// <summary>
  /// Writes additional information about the variable into the SavedObject proto.
  /// Subclasses of ResourceVariables can override this method to customize
  /// the extra information provided when saving a SavedModel.
  /// </summary>
  /// <param name="proto">Proto that receives the information.</param>
  /// <param name="options">Save options forwarded to the writer.</param>
  public virtual void write_object_proto(SavedObject proto, SaveOptions options)
      => resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
  /// <summary>
  /// Returns a single-entry dictionary that maps
  /// <c>Trackable.Constants.VARIABLE_VALUE_KEY</c> to a factory returning this variable.
  /// </summary>
  public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
  {
      return new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>
      {
          // The factory ignores its string argument.
          [Trackable.Constants.VARIABLE_VALUE_KEY] = x => this
      };
  }
  /// <summary>
  /// Returns the result of <c>gen_resource_variable_ops.var_is_initialized_op</c>
  /// for this variable's handle.
  /// </summary>
  /// <param name="name">Optional name forwarded to the op.</param>
  public Tensor is_initialized(string name = null)
      => gen_resource_variable_ops.var_is_initialized_op(this.handle, name);
  /// <summary>
  /// Reads the variable inside the "Read" name scope and returns
  /// <c>array_ops.identity</c> of the result. The no-copy behaviour the name
  /// refers to is still a TODO in the body.
  /// </summary>
  public Tensor read_value_no_copy()
  {
      Tensor snapshot = null;

      tf_with(ops.name_scope("Read"), _ =>
      {
          // TODO: `no_copy = true`.
          snapshot = _read_variable_op();
      });

      return array_ops.identity(snapshot);
  }
  276. }
  277. }