You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RefVariable.cs 16 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using Google.Protobuf;
  14. using System;
  15. using System.Collections.Generic;
  16. using System.Linq;
  17. using static Tensorflow.Binding;
  18. namespace Tensorflow
  19. {
  20. public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable>
  21. {
  /// <summary>Graph-mode flag. Initialized to <c>true</c> and set to <c>true</c> again by the constructor.</summary>
  public bool _in_graph_mode = true;

  /// <summary>
  /// Tensor used to initialize the variable. Assigned by <c>_init_from_args</c>, or looked up by name in
  /// <c>_init_from_proto</c> (left <c>null</c> there when the proto has no initial value name).
  /// </summary>
  public Tensor _initial_value;

  /// <summary>Graph key of the default graph, recorded by <c>_init_from_args</c>.</summary>
  public string _graph_key;

  /// <summary>Whether the variable was created (or deserialized) as trainable.</summary>
  public bool _trainable;

  /// <summary>The "read" tensor of the variable; this is what <c>value()</c> and <c>_AsTensor()</c> return.</summary>
  public Tensor _snapshot;

  /// <summary>When <c>true</c>, <c>to_proto</c> throws because slice info export is not implemented.</summary>
  public bool _save_slice_info;

  // Backing field for the initializer property.
  private Operation _initializer_op;

  /// <summary>The operation that assigns the initial value to the variable.</summary>
  public override Operation initializer => _initializer_op;

  /// <summary>The operation that produces the underlying variable tensor.</summary>
  public override Operation op => _variable.op;

  /// <summary>Data type of the underlying variable tensor.</summary>
  public TF_DataType dtype => _variable.dtype;

  /// <summary>Shape of the underlying variable tensor, converted with <c>tensor_util.to_shape</c>.</summary>
  public TensorShape shape => tensor_util.to_shape(_variable.shape);

  /// <summary>Name of the underlying variable tensor.</summary>
  public override string name => _variable.name;

  /// <summary>Returns the underlying variable tensor.</summary>
  public Tensor eval() => _variable;
  /// <summary>
  /// Creates a variable either from a serialized <c>VariableDef</c> or from an initial value.
  /// </summary>
  /// <param name="initial_value">Initial value; must be <c>null</c> when <paramref name="variable_def"/> is given.</param>
  /// <param name="trainable">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
  /// <param name="collections">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
  /// <param name="validate_shape">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
  /// <param name="caching_device">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
  /// <param name="name">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
  /// <param name="variable_def">Serialized definition; when non-null the variable is restored from it.</param>
  /// <param name="dtype">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
  /// <param name="import_scope">Name scope passed to <c>_init_from_proto</c>.</param>
  /// <exception cref="ValueError">Both <paramref name="variable_def"/> and <paramref name="initial_value"/> were supplied.</exception>
  public RefVariable(object initial_value = null,
      bool trainable = true,
      List<string> collections = null,
      bool validate_shape = true,
      string caching_device = "",
      string name = null,
      VariableDef variable_def = null,
      TF_DataType dtype = TF_DataType.DtInvalid,
      string import_scope = "") : base(initial_value,
          trainable,
          collections,
          validate_shape,
          caching_device,
          name,
          dtype)
  {
      _in_graph_mode = true;

      // Without a proto, the variable is built from the constructor arguments.
      if (variable_def == null)
      {
          _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
          return;
      }

      // A proto already names its initial value, so an explicit one is rejected.
      if (initial_value != null)
          throw new ValueError("variable_def and initial_value are mutually exclusive.");

      _init_from_proto(variable_def, import_scope: import_scope);
  }
  /// <summary>
  /// Restores this variable's tensors and initializer by looking up the names stored in
  /// <paramref name="variable_def"/> in the default graph.
  /// </summary>
  /// <param name="variable_def">Serialized variable definition to read names and flags from.</param>
  /// <param name="import_scope">Name scope prepended to every name before the graph lookup.</param>
  /// <exception cref="NotImplementedException">The proto carries a save-slice-info definition.</exception>
  private void _init_from_proto(VariableDef variable_def, string import_scope = "")
  {
      var graph = ops.get_default_graph();

      var variable_name = ops.prepend_name_scope(variable_def.VariableName,
          import_scope: import_scope);
      _variable = graph.as_graph_element(variable_name) as Tensor;

      var initializer_name = ops.prepend_name_scope(variable_def.InitializerName,
          import_scope: import_scope);
      _initializer_op = graph.as_graph_element(initializer_name) as Operation;

      // Tests whether initial_value_name exists first for backwards compatibility.
      if (string.IsNullOrEmpty(variable_def.InitialValueName))
      {
          _initial_value = null;
      }
      else
      {
          var initial_value_name = ops.prepend_name_scope(variable_def.InitialValueName,
              import_scope: import_scope);
          _initial_value = graph.as_graph_element(initial_value_name) as Tensor;
      }

      _trainable = variable_def.Trainable;

      var snapshot_name = ops.prepend_name_scope(variable_def.SnapshotName,
          import_scope: import_scope);
      _snapshot = graph.as_graph_element(snapshot_name) as Tensor;

      // Slice info, caching device and constraint are not restored by this port.
      if (variable_def.SaveSliceInfoDef != null)
          throw new NotImplementedException("save_slice_info_def");
  }
  /// <summary>
  /// Builds the variable tensor, its initializer op and its "read" snapshot from an initial value,
  /// then registers the variable in the requested graph collections.
  /// </summary>
  /// <param name="initial_value">
  /// Either a <c>Func&lt;Tensor&gt;</c> that produces the initial tensor, or an object that
  /// <c>ops.convert_to_tensor</c> can convert.
  /// </param>
  /// <param name="trainable">When <c>true</c>, the variable is also added to the trainable-variables collection.</param>
  /// <param name="collections">Collections to register in; defaults to the global-variables collection. Not modified.</param>
  /// <param name="validate_shape">When <c>true</c>, the initial value must have a fully defined shape.</param>
  /// <param name="caching_device">When non-empty, no snapshot tensor is created (see the note in the body).</param>
  /// <param name="name">Name for the variable's scope; replaced by the resolved scope name.</param>
  /// <param name="dtype">Requested element type passed to the tensor conversion.</param>
  /// <exception cref="ValueError">
  /// <paramref name="initial_value"/> is null, is a callable that does not return a <c>Tensor</c>,
  /// or (with <paramref name="validate_shape"/>) has a shape that is not fully defined.
  /// </exception>
  private void _init_from_args(object initial_value,
      bool trainable = true,
      List<string> collections = null,
      bool validate_shape = true,
      string caching_device = "",
      string name = null,
      TF_DataType dtype = TF_DataType.DtInvalid)
  {
      if (initial_value is null)
          throw new ValueError("initial_value must be specified.");

      // Resolve the callable once. Any other single-result Func would otherwise be
      // invoked through a null reference further down.
      var init_fn = initial_value as Func<Tensor>;
      if (init_fn == null && initial_value.GetType().Name == "Func`1")
          throw new ValueError($"initial_value callable must return a Tensor, but got {initial_value.GetType()}.");
      var init_from_fn = init_fn != null;

      if (collections == null)
      {
          collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES };
      }

      // Store the graph key so optimizers know how to only retrieve variables from
      // this graph.
      _graph_key = ops.get_default_graph().graph_key;

      _trainable = trainable;

      // Extend a copy so the list supplied by the caller is left untouched.
      if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
          collections = new List<string>(collections) { tf.GraphKeys.TRAINABLE_VARIABLES };

      tf_with(ops.init_scope2(), delegate
      {
          var values = init_from_fn ? new object[0] : new object[] { initial_value };
          tf_with(ops.name_scope(name, "Variable", values), scope =>
          {
              name = scope;

              if (init_from_fn)
              {
                  // Use attr_scope and device(None) to simulate the behavior of
                  // colocate_with when the variable we want to colocate with doesn't
                  // yet exist.
                  // NOTE(review): the attribute below is built but never applied to the graph.
                  string true_name = ops.name_from_scope_name(name);
                  var attr = new AttrValue
                  {
                      List = new AttrValue.Types.ListValue()
                  };
                  attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));

                  tf_with(ops.name_scope("Initializer"), scope2 =>
                  {
                      _initial_value = init_fn();
                      _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
                  });
                  _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
              }
              // Or get the initial value from a Tensor or Python object.
              else
              {
                  _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);

                  var shape = _initial_value.shape;
                  dtype = _initial_value.dtype;
                  _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
              }

              // Manually overrides the variable's shape with the initial value's.
              if (validate_shape)
              {
                  var initial_value_shape = _initial_value.TensorShape;
                  if (!initial_value_shape.is_fully_defined())
                      throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
              }

              // If 'initial_value' makes use of other variables, make sure we don't
              // have an issue if these other variables aren't initialized first by
              // using their initialized_value() method.
              var guarded_initial_value = _try_guard_against_uninitialized_dependencies(name, _initial_value);

              _initializer_op = gen_state_ops.assign(_variable, guarded_initial_value, validate_shape).op;

              // NOTE(review): with a caching device no snapshot is created, so _snapshot
              // stays null and value()/to_proto cannot be used; device placement is not ported.
              if (String.IsNullOrEmpty(caching_device))
              {
                  // NOTE(review): the result of colocate_with is discarded here.
                  ops.colocate_with(_initializer_op);

                  // The op name is passed positionally; the previous `name = "read"` was an
                  // assignment expression that silently overwrote the captured `name`.
                  _snapshot = gen_array_ops.identity(_variable, "read");
              }

              ops.add_to_collections(collections, this as VariableV1);
          });
      });
  }
  /// <summary>Returns the underlying variable tensor (the mutable reference).</summary>
  public Tensor _ref() => _variable;

  /// <summary>Returns the snapshot ("read") tensor of the variable.</summary>
  public Tensor value() => _snapshot;

  /// <summary>Returns the snapshot tensor; same result as <c>value()</c>.</summary>
  public Tensor _AsTensor() => _snapshot;

  /// <summary>Returns the underlying variable tensor; same result as <c>_ref()</c>.</summary>
  public Tensor _as_graph_element() => _variable;
  /// <summary>
  /// Converts this variable to a tensor: the reference tensor when <paramref name="as_ref"/> is set,
  /// otherwise the snapshot value.
  /// </summary>
  /// <param name="dtype">Not used by this implementation.</param>
  /// <param name="name">Not used by this implementation.</param>
  /// <param name="as_ref">Selects <c>_ref()</c> instead of <c>value()</c>.</param>
  /// <returns>The selected tensor.</returns>
  public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
      => as_ref ? _ref() : value();
  /// <summary>
  /// Attempt to guard against dependencies on uninitialized variables by rewriting
  /// <paramref name="initial_value"/> with a fresh, empty operation cache.
  /// </summary>
  /// <param name="name">Name passed through to the rewriting helpers.</param>
  /// <param name="initial_value">The tensor whose dependencies are inspected.</param>
  /// <returns>The tensor produced by <c>_safe_initial_value_from_tensor</c>.</returns>
  private Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value)
  {
      var fresh_cache = new Dictionary<string, Operation>();
      return _safe_initial_value_from_tensor(name, initial_value, fresh_cache);
  }
  /// <summary>
  /// Replace dependencies on variables with their initialized values.
  /// </summary>
  /// <param name="name">Name passed through to <c>_safe_initial_value_from_op</c>.</param>
  /// <param name="tensor">A `Tensor`. The tensor to replace.</param>
  /// <param name="op_cache">A dict mapping operation names to `Operation`s; filled in as ops are processed.</param>
  /// <returns>The output of the (possibly replaced) producing op at the same output index as <paramref name="tensor"/>.</returns>
  private Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary<string, Operation> op_cache)
  {
      var producer = tensor.op;
      var cache_key = producer.name;

      // Reuse the replacement computed earlier for this op, if any.
      if (!op_cache.TryGetValue(cache_key, out var replacement))
          replacement = null;

      if (replacement == null)
      {
          replacement = _safe_initial_value_from_op(name, producer, op_cache);
          op_cache[cache_key] = replacement;
      }

      return replacement.outputs[tensor.value_index];
  }
  /// <summary>
  /// Returns an operation equivalent to <paramref name="op"/> whose variable inputs are replaced
  /// by their guarded initialized values.
  /// </summary>
  /// <param name="name">Suffix appended to the name of any op that has to be re-created.</param>
  /// <param name="op">The operation to inspect.</param>
  /// <param name="op_cache">Cache of already processed operations, keyed by op name.</param>
  /// <returns>
  /// <paramref name="op"/> itself when nothing had to change, the op of the initialized value for
  /// variable ops, or a newly created op when at least one input was replaced.
  /// </returns>
  private Operation _safe_initial_value_from_op(string name, Operation op, Dictionary<string, Operation> op_cache)
  {
      var op_type = op.node_def.Op;

      // Ops that only query or read a variable are returned unchanged.
      if (op_type == "IsVariableInitialized"
          || op_type == "VarIsInitializedOp"
          || op_type == "ReadVariableOp")
      {
          return op;
      }

      // Variable ops are swapped for their initialized value when one can be found.
      if (op_type == "Variable"
          || op_type == "VariableV2"
          || op_type == "VarHandleOp")
      {
          var initialized_value = _find_initialized_value_for_variable(op);
          return initialized_value == null ? op : initialized_value.op;
      }

      // Recursively build initializer expressions for inputs.
      var modified = false;
      var rewritten_inputs = new List<Tensor>();
      foreach (var op_input in op.inputs)
      {
          var rewritten_input = _safe_initial_value_from_tensor(name, op_input as Tensor, op_cache);
          rewritten_inputs.Add(rewritten_input);
          if (!modified && rewritten_input != op_input)
              modified = true;
      }

      // Nothing changed: keep the original op.
      if (!modified)
          return op;

      // If at least one input was modified, replace the op.
      var new_op_type = op_type == "RefSwitch" ? "Switch" : op_type;
      var new_op_name = (op.node_def.Name + "_" + name).Replace(":", "_");

      // Convert attr values to AttrValue protos.
      var attr_protos = new Dictionary<string, AttrValue>();
      foreach (var attr_def in op.node_def.Attr)
          attr_protos[attr_def.Key] = attr_def.Value;

      return op.graph.create_op(new_op_type, rewritten_inputs.ToArray(), op._output_types,
          name: new_op_name, attrs: attr_protos);
  }
  /// <summary>
  /// Searches the global and local variable collections of the op's graph for the variable whose
  /// name matches <paramref name="variable_op"/> and returns its initialized value.
  /// </summary>
  /// <param name="variable_op">The variable operation to look up.</param>
  /// <returns>The result of the matching variable's <c>initialized_value()</c>, or <c>null</c> when none matches.</returns>
  private Operation _find_initialized_value_for_variable(Operation variable_op)
  {
      // A variable may be registered under the op name or under its first output's name.
      var node_name = variable_op.node_def.Name;
      var var_names = new[] { node_name, node_name + ":0" };

      var collection_names = new[] { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES };
      foreach (var collection_name in collection_names)
      {
          foreach (var candidate in variable_op.graph.get_collection<RefVariable>(collection_name))
          {
              if (var_names.Contains(candidate.name))
                  return candidate.initialized_value();
          }
      }

      return null;
  }
  /// <summary>
  /// Assigns a new value to the variable.
  /// </summary>
  /// <param name="value">The new value for this variable.</param>
  /// <param name="use_locking">If `True`, use locking during the assignment.</param>
  /// <param name="name">The name of the operation to be created.</param>
  /// <param name="read_value">
  /// If true, the assign tensor is returned (it evaluates to the new value of the variable);
  /// if false, the assign operation itself is returned.
  /// </param>
  /// <returns>
  /// The tensor produced by the assignment, or its operation when
  /// <paramref name="read_value"/> is false.
  /// </returns>
  public ITensorOrOperation assign(object value, bool use_locking = false, string name = null, bool read_value = true)
  {
      var assigned = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);

      if (!read_value)
          return assigned.op;

      return assigned;
  }
  /// <summary>Returns a short description containing the variable's name, shape and dtype.</summary>
  public override string ToString()
      => $"tf.RefVariable '{name}' shape={shape} dtype={dtype}";
  /// <summary>
  /// Serializes this variable into a <c>VariableDef</c>, stripping <paramref name="export_scope"/>
  /// from every recorded name.
  /// </summary>
  /// <param name="export_scope">Name scope to strip; may be null or empty to export names unchanged.</param>
  /// <returns>The populated <c>VariableDef</c>.</returns>
  /// <exception cref="NotImplementedException">
  /// The variable's name does not start with <paramref name="export_scope"/>, or slice info is set.
  /// </exception>
  public VariableDef to_proto(string export_scope)
  {
      // Graph names are identifiers, so the prefix test must be ordinal: the default
      // culture-sensitive StartsWith can match differently depending on the current culture.
      var in_scope = string.IsNullOrEmpty(export_scope)
          || _variable.name.StartsWith(export_scope, StringComparison.Ordinal);
      if (!in_scope)
          throw new NotImplementedException("to_proto RefVariable");

      var var_def = new VariableDef();
      var_def.VariableName = ops.strip_name_scope(_variable.name, export_scope);

      // The initial value is absent when the variable was restored from a proto without one.
      if (_initial_value != null)
          var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope);

      var_def.Trainable = _trainable;
      var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope);
      var_def.SnapshotName = ops.strip_name_scope(_snapshot.name, export_scope);

      if (_save_slice_info)
          throw new NotImplementedException("to_proto _save_slice_info");

      return var_def;
  }
  /// <summary>Not implemented; always throws.</summary>
  /// <param name="proto">Unused.</param>
  /// <param name="import_scope">Unused.</param>
  /// <exception cref="NotImplementedException">Always.</exception>
  public RefVariable from_proto(VariableDef proto, string import_scope)
      => throw new NotImplementedException();
  /// <summary>
  /// Creates an identity op named "read" on the variable tensor and returns its result.
  /// </summary>
  /// <returns>The identity tensor of the variable.</returns>
  private ITensorOrOperation read_value()
      => array_ops.identity(_variable, name: "read");
  /// <summary>
  /// Returns the tensor stored as this variable's initial value.
  /// </summary>
  /// <returns>The <c>_initial_value</c> field.</returns>
  private ITensorOrOperation initial_value()
      => _initial_value;
  /// <summary>Builds the tensor that reports whether <paramref name="variable"/> has been initialized.</summary>
  /// <param name="variable">The variable to test.</param>
  /// <returns>The result of <c>state_ops.is_variable_initialized</c>.</returns>
  public Tensor is_variable_initialized(RefVariable variable)
      => state_ops.is_variable_initialized(variable);
  /// <summary>
  /// Returns a conditional tensor that yields the variable's current value when it is initialized
  /// and its initial value otherwise.
  /// </summary>
  /// <returns>The result of <c>control_flow_ops.cond</c> over the two branches.</returns>
  public Tensor initialized_value()
  {
      ops.init_scope();

      // Predicate first, then the two branch builders (read vs. initial value).
      var is_initialized = is_variable_initialized(this);
      return control_flow_ops.cond(is_initialized,
          read_value,
          initial_value);
  }
  329. }
  330. }