You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseResourceVariable.cs 12 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. using Tensorflow.NumPy;
  2. using System;
  3. using Tensorflow.Eager;
  4. using Tensorflow.Variables;
  5. using Tensorflow.Train;
  6. using static Tensorflow.Binding;
  7. using System.Collections.Generic;
  8. using System.Diagnostics;
  9. using Tensorflow.Checkpoint;
  10. using Tensorflow.Training.Saving.SavedModel;
  11. using OneOf;
  12. using Tensorflow.Graphs;
  namespace Tensorflow
  {
      /// <summary>
      /// Shared implementation for variables whose storage is addressed through a resource
      /// handle tensor (<see cref="Handle"/>): reading, assignment, eager handle bookkeeping
      /// and SavedModel / checkpoint integration.
      /// </summary>
      public class BaseResourceVariable : DisposableTrackableObject
      {
          protected string _name;

          /// <summary>The handle name with ":0" appended (as set by <c>__init__</c>).</summary>
          public virtual string Name => _handle_name;

          /// <summary>The raw <c>name</c> passed to <c>__init__</c>.</summary>
          public virtual string SharedName => _name;

          protected TF_DataType _dtype;
          public TF_DataType dtype => _dtype;

          protected string _handle_name;
          public string handle_name
          {
              get => _handle_name;
              set => _handle_name = value;
          }

          protected string _unique_id;
          public string UniqueId => _unique_id;

          protected bool _in_graph_mode;
          internal bool InGraphMode => _in_graph_mode;

          protected bool _trainable;

          /// <summary>When true, reads of this variable are reported to the active gradient tapes.</summary>
          public bool Trainable => _trainable;

          protected Tensor _initial_value;

          /// <summary>Alias of <see cref="Initializer"/>.</summary>
          public Operation initializer => initializer_op;

          protected Tensor _parent_op;
          public Tensor parent_op => _parent_op;

          /// <summary>
          /// Tensor handle
          /// </summary>
          protected Tensor handle;
          public Tensor Handle => handle;

          protected Tensor _graph_element;

          /// <summary>When set, <see cref="value"/> returns this tensor instead of issuing a new read.</summary>
          public Tensor GraphElement => _graph_element;

          protected Shape _shape;
          public Shape shape => _shape;

          protected Operation initializer_op;
          public Operation Initializer => initializer_op;

          public Operation Op => handle.op;
          public Graph Graph => handle.graph;
          public string Device => handle.Device;

          // Sole owner of the eager deleter, so it is collected together with this variable.
          EagerResourceDeleter eager_resource_deleter;

          public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None;

          public BaseResourceVariable()
          {
          }

          /// <summary>
          /// Python-style initializer that populates this variable from an existing handle.
          /// </summary>
          /// <param name="trainable">Whether reads are reported to active gradient tapes.</param>
          /// <param name="shape">Static shape; left unchanged when <c>null</c>.</param>
          /// <param name="dtype">Element type; left unchanged when <c>TF_DataType.DtInvalid</c>.</param>
          /// <param name="handle">Resource handle tensor. An eager handle also gets an <c>EagerResourceDeleter</c>.</param>
          /// <param name="name">Stored as <see cref="SharedName"/>.</param>
          /// <param name="unique_id">Stored as <see cref="UniqueId"/>.</param>
          /// <param name="handle_name">Base handle name; ":0" is appended to form <see cref="Name"/>.</param>
          public void __init__(bool trainable = true,
              Shape shape = null,
              TF_DataType dtype = TF_DataType.DtInvalid,
              Tensor handle = null,
              string name = null,
              string unique_id = null,
              string handle_name = null)
          {
              _trainable = trainable;
              _handle_name = handle_name + ":0";
              _unique_id = unique_id;
              this.handle = handle;
              _name = name;

              if (shape is not null)
              {
                  _shape = shape;
              }

              if (dtype != TF_DataType.DtInvalid)
              {
                  _dtype = dtype;
              }

              // After the handle has been created, set up a way to clean it up when
              // executing eagerly. We'll hold the only reference to the deleter, so that
              // when this object is garbage collected the deleter will be too. This
              // means ResourceVariables can be part of reference cycles without those
              // cycles being uncollectable.
              if (handle is EagerTensor)
              {
                  _handle = handle.EagerTensorHandle.DangerousGetHandle();
                  eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
              }
              else if (handle is null)
              {
                  // TODO: fix this dangerous change.
                  _handle = IntPtr.Zero;
              }
              else
              {
                  _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle();
              }
              #if TRACK_TENSOR_LIFE
              print($"Created Resource 0x{_handle.ToString("x16")} {_name}");
              #endif
          }

          /// <summary>
          /// Assigns <paramref name="value"/> to this variable through <c>AssignVariableOp</c>.
          /// </summary>
          /// <param name="value">New value; converted to a tensor of this variable's <see cref="dtype"/>.</param>
          /// <param name="use_locking">Not used by the resource assign op; kept for API compatibility.</param>
          /// <param name="name">Optional name for the assign op.</param>
          /// <param name="read_value">
          /// When true, returns a fresh read of the variable; otherwise returns the assign op
          /// (or null when no op object was produced).
          /// </param>
          public Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
          {
              // Resource variables are always updated with AssignVariableOp, whatever the runtime
              // type of `value`. The ref-variable `Assign` op (gen_state_ops.assign) expects a
              // ref-typed input, not a resource handle, so it must not be used here.
              var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
              var assign_op = gen_resource_variable_ops.assign_variable_op(
                  handle, value_tensor, name: name);
              if (read_value)
                  return gen_resource_variable_ops.read_variable_op(handle, dtype);
              if (assign_op is null)
                  return null;
              return assign_op;
          }

          /// <summary>
          /// Assigns <paramref name="value"/> into the slice described by <paramref name="slice"/>.
          /// </summary>
          public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
          {
              // NOTE(review): only the packed begin/end/strides are forwarded; all masks use the
              // defaults of _strided_slice_assign (0). Confirm `slice` carries no mask info that
              // should be passed through.
              _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
          }

          // Emits a resource strided-slice assignment on this variable's handle.
          // NOTE(review): `name` is accepted but not forwarded to the op.
          void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
              int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
          {
              gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
                  begin_mask: begin_mask,
                  end_mask: end_mask,
                  ellipsis_mask: ellipsis_mask,
                  new_axis_mask: new_axis_mask,
                  shrink_axis_mask: shrink_axis_mask);
          }

          /// <summary>
          /// Assigns <paramref name="value"/> and returns an unread-variable view of this handle
          /// instead of reading the new value immediately.
          /// </summary>
          public IVariableV1 assign_lazy_load(Tensor value, string name = null)
          {
              var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
              var assign_op = gen_resource_variable_ops.assign_variable_op(
                  handle, value_tensor, name: name);
              var variable = _lazy_read(assign_op, value_tensor);
              return variable;
          }

          /// <summary>
          /// Returns <see cref="GraphElement"/> when set, otherwise a new read of the variable.
          /// </summary>
          public Tensor value()
              => GraphElement ?? _read_variable_op();

          /// <summary>
          /// Records the access and emits a <c>ReadVariableOp</c> on the handle.
          /// </summary>
          protected Tensor _read_variable_op()
          {
              variable_accessed(this);
              var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
              resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);

              // Have to set the shape when converting to a substituent placeholder. Skip when this
              // variable has no recorded shape, which would otherwise dereference null.
              if (result.shape.ndim == -1 && _shape is not null)
              {
                  c_api.TF_GraphSetTensorShape(result.graph,
                      result._as_tf_output(),
                      _shape.dims,
                      _shape.ndim,
                      tf.Status);
                  tf.Status.Check(true);
              }
              return result;
          }

          // Records the access and returns an _UnreadVariable sharing this variable's handle,
          // dtype, shape, graph-mode flag and unique id. `op` and `value` are currently unused.
          IVariableV1 _lazy_read(Operation op, Tensor value)
          {
              variable_accessed(this);
              return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
          }

          /// <summary>
          /// Records that `variable` was accessed for the tape and FuncGraph.
          /// </summary>
          void variable_accessed(BaseResourceVariable variable)
          {
              if (ops.get_default_graph() is FuncGraph func_graph)
              {
                  func_graph.watch_variable(variable as IVariableV1);
              }

              // Pattern-match instead of an `as` cast so that a trainable variable which is not a
              // ResourceVariable is skipped rather than handing null to every tape.
              if (variable.Trainable && variable is ResourceVariable resource_variable)
              {
                  foreach (var tape in tf.GetTapeSet())
                      tape.VariableAccessed(resource_variable);
              }
          }

          /// <summary>
          /// Constructs an op which reads the value of this variable.
          ///
          /// Should be used when there are multiple reads, or when it is desirable to
          /// read the value only after some condition is true.
          /// </summary>
          /// <returns>An identity of the read, created under a "Read" name scope.</returns>
          protected Tensor read_value()
          {
              var value = tf_with(ops.name_scope("Read"), delegate
              {
                  return _read_variable_op();
              });
              return array_ops.identity(value);
          }

          /// <summary>
          /// Adds <paramref name="delta"/> to this variable through <c>AssignAddVariableOp</c>.
          /// </summary>
          /// <param name="use_locking">Not used by the resource op; kept for API compatibility.</param>
          /// <param name="read_value">When true, returns a fresh read; otherwise the op (or null).</param>
          public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
          {
              var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
                  ops.convert_to_tensor(delta, dtype: dtype), name: name);
              if (read_value)
                  return gen_resource_variable_ops.read_variable_op(handle, dtype);
              // Same null guard as assign(): no op object may be produced, e.g. when executing eagerly.
              if (assign_add_op is null)
                  return null;
              return assign_add_op;
          }

          /// <summary>
          /// Subtracts <paramref name="delta"/> from this variable through <c>AssignSubVariableOp</c>.
          /// </summary>
          /// <param name="use_locking">Not used by the resource op; kept for API compatibility.</param>
          /// <param name="read_value">When true, returns a fresh read; otherwise the op (or null).</param>
          public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
          {
              var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
                  ops.convert_to_tensor(delta, dtype: dtype), name: name);
              if (read_value)
                  return gen_resource_variable_ops.read_variable_op(handle, dtype);
              // Same null guard as assign(): no op object may be produced, e.g. when executing eagerly.
              if (assign_sub_op is null)
                  return null;
              return assign_sub_op;
          }

          /// <summary>
          /// Subtracts <paramref name="delta"/> and returns an unread-variable view of this handle.
          /// </summary>
          public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
          {
              var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
                  ops.convert_to_tensor(delta, dtype: dtype), name: name);
              return _lazy_read(assign_sub_op, delta);
          }

          /// <summary>
          /// Describes the variable; in eager mode the current value is appended.
          /// </summary>
          public override string ToString()
          {
              var summary = $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
              return tf.Context.executing_eagerly()
                  ? $"{summary}, numpy={read_value().numpy()}"
                  : summary;
          }

          /// <summary>Reads the variable and returns its value as an <see cref="NDArray"/>.</summary>
          public NDArray numpy() => read_value().numpy();

          protected override void DisposeUnmanagedResources(IntPtr handle)
          {
              #if TRACK_TENSOR_LIFE
              print($"Deleted Resource 0x{handle.ToString("x16")} {_name}");
              #endif
          }

          /// <summary>
          /// Converts the variable to a tensor.
          /// </summary>
          /// <param name="dtype">Currently ignored.</param>
          /// <param name="name">Currently ignored.</param>
          /// <param name="as_ref">
          /// When true, returns the first input of the identity op created by a fresh read;
          /// otherwise returns <see cref="value"/>.
          /// </param>
          public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
          {
              if (as_ref)
                  return read_value().op.inputs[0];
              return value();
          }

          /// <summary>
          /// Creates an uninitialized copy of this variable for SavedModel export and maps this
          /// variable and its handle onto the copy. Placed on this variable's device when the
          /// save options request device preservation.
          /// </summary>
          public override (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(SaveOptions save_options)
          {
              // Both branches require a ResourceVariable; check and cast once.
              Debug.Assert(this is ResourceVariable);
              var resource_variable = (ResourceVariable)this;

              BaseResourceVariable new_variable;
              if (save_options.experimental_variable_policy.save_variable_devices())
              {
                  new_variable = tf_with(ops.device(this.Device), _ =>
                  {
                      return resource_variable_ops.copy_to_graph_uninitialized(resource_variable);
                  });
              }
              else
              {
                  new_variable = resource_variable_ops.copy_to_graph_uninitialized(resource_variable);
              }

              Dictionary<Trackable, Trackable> obj_map = new();
              Dictionary<Tensor, Tensor> resource_map = new();
              obj_map[this] = new_variable;
              resource_map[this.handle] = new_variable.handle;
              return (obj_map, resource_map);
          }

          /// <summary>
          /// Writes additional information of the variable into the SavedObject proto.
          /// Subclasses of ResourceVariables could choose to override this method to
          /// customize extra information to provide when saving a SavedModel.
          /// </summary>
          /// <param name="proto">Proto to populate.</param>
          /// <param name="options">Save options forwarded to the writer.</param>
          public virtual void write_object_proto(SavedObject proto, SaveOptions options)
          {
              resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
          }

          /// <summary>
          /// Exposes this variable itself as the single checkpoint saveable, keyed by
          /// <c>Trackable.Constants.VARIABLE_VALUE_KEY</c>.
          /// </summary>
          public override IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
          {
              var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>();
              res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this;
              return res;
          }

          /// <summary>Emits a <c>VarIsInitializedOp</c> on this variable's handle.</summary>
          public Tensor is_initialized(string name = null)
          {
              return gen_resource_variable_ops.var_is_initialized_op(this.handle, name);
          }

          /// <summary>
          /// Reads the variable under a "Read" name scope and returns an identity of the result.
          /// </summary>
          public Tensor read_value_no_copy()
          {
              Tensor value = null;
              tf_with(ops.name_scope("Read"), _ =>
              {
                  // TODO: `no_copy = true`.
                  value = _read_variable_op();
              });
              return array_ops.identity(value);
          }
      }
  }