You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FuncGraph.cs 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. using Google.Protobuf;
  2. using System;
  3. using System.Buffers;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using Tensorflow.Eager;
  7. using Tensorflow.Exceptions;
  8. using Tensorflow.Framework;
  9. using Tensorflow.Framework.Models;
  10. using Tensorflow.Functions;
  11. using Tensorflow.NumPy;
  12. using Tensorflow.Operations;
  13. using Tensorflow.Util;
  14. using static Tensorflow.Binding;
  15. namespace Tensorflow.Graphs;
  /// <summary>
  /// Graph representing a function body. Tensors from other graphs (or eager tensors)
  /// used while building it are captured as in-graph constants or placeholders.
  /// </summary>
  public class FuncGraph : Graph, IDisposable
  {
      internal SafeFuncGraphHandle _func_graph_handle;
      // Resource tensors passed in as function arguments; watch_variable ignores
      // variables whose handle is in this set.
      internal HashSet<Tensor> _resource_tensor_inputs;
      // Variables accessed while this graph was being built, held weakly.
      internal HashSet<WeakReference<IVariableV1>> _watched_variables;
      // Weak references backing the Variables property.
      internal IEnumerable<WeakReference<IVariableV1>> _weak_variables;
      // Raw (unflattened) outputs returned by the traced function.
      internal object[] _structured_outputs;
      internal Dictionary<long, string> _output_names;

      /// <summary>Name of the function; backed by the graph key (updated by ToGraph).</summary>
      public string FuncName => _graph_key;
      public Tensors Inputs { get; set; } = new Tensors();
      public Tensors Outputs { get; set; } = new Tensors();

      /// <summary>
      /// The structured outputs flattened into a single list of tensors. A member that is a
      /// tensor is added as-is, a member that is a tensor sequence is spliced in, and null
      /// members (produced for null function outputs) are skipped. Returns an empty list when
      /// no structured outputs have been recorded.
      /// </summary>
      /// <exception cref="TypeError">A member is neither a tensor nor a tensor sequence.</exception>
      public Tensors FlatStructuredOutputs
      {
          get
          {
              List<Tensor> res = new();
              if (_structured_outputs is null)
              {
                  return res;
              }
              foreach (var obj in _structured_outputs)
              {
                  switch (obj)
                  {
                      case null:
                          // Null outputs carry no tensor; func_graph_from_func filters them anyway.
                          break;
                      case Tensor tensor:
                          res.Add(tensor);
                          break;
                      case IEnumerable<Tensor> tensors:
                          res.AddRange(tensors);
                          break;
                      default:
                          throw new TypeError("The structured outputs member should be tensor or tensors.");
                  }
              }
              return res;
          }
      }

      public string Name { get; set; }

      /// <summary>
      /// Variables used by this function. Enumerating throws <see cref="AssertionError"/> if a
      /// referenced variable has already been garbage collected.
      /// </summary>
      public IEnumerable<IVariableV1> Variables
      {
          get
          {
              return _weak_variables.Select(v =>
              {
                  if (v.TryGetTarget(out var target))
                  {
                      return target;
                  }
                  throw new AssertionError("Called a function referencing variables which have been deleted. " +
                      "This likely means that function-local variables were created and " +
                      "not referenced elsewhere in the program. This is generally a " +
                      "mistake; consider storing variables in an object attribute on first call.");
              });
          }
          internal set
          {
              // Materialize once: a deferred Select would re-run `value` on every read
              // (re-resolving or strongly holding the variables) instead of keeping
              // one weak reference per variable.
              _weak_variables = value.Select(x => new WeakReference<IVariableV1>(x)).ToList();
          }
      }

      /// <summary>Subset of <see cref="Variables"/> whose Trainable flag is set.</summary>
      public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);

      public Dictionary<string, AttrValue> Attrs { get; set; }

      // Keyed by the captured (external) tensor's Id; value is (external tensor, internal substitute).
      internal Dictionary<long, (Tensor, Tensor)> _captures
          = new Dictionary<long, (Tensor, Tensor)>();

      /// <summary>Captured tensors as they exist outside this graph.</summary>
      public Tensor[] external_captures
          => _captures.Select(x => x.Value.Item1).ToArray();

      /// <summary>(external tensor, internal substitute) pairs for every capture.</summary>
      public (Tensor, Tensor)[] captures
          => _captures.Values.ToArray();

      /// <summary>In-graph substitutes (constants/placeholders) for the captured tensors.</summary>
      public Tensor[] internal_captures
          => _captures.Select(x => x.Value.Item2).ToArray();

      /// <summary>Same as <see cref="external_captures"/>.</summary>
      public Tensor[] captured_inputs
          => external_captures;

      /// <summary>
      /// Construct a new FuncGraph whose outer graph is the nearest enclosing graph that is
      /// not itself building a function.
      /// </summary>
      public FuncGraph(string name) : base()
      {
          outer_graph = ops.get_default_graph();
          while (outer_graph.building_function)
              outer_graph = outer_graph.OuterGraph;
          _graph_key = Name = name;
          building_function = true;
          _weak_variables = new List<WeakReference<IVariableV1>>();
          _resource_tensor_inputs = new HashSet<Tensor>();
          _watched_variables = new HashSet<WeakReference<IVariableV1>>();
      }

      /// <summary>
      /// Construct a FuncGraph over an existing native graph handle with the given function attributes.
      /// </summary>
      public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, AttrValue> attrs) : this(name)
      {
          Attrs = attrs;
          // Will to test if FuncGraph has memory leak
          // c_api.TF_DeleteGraph(_handle);
          _handle = handle;
      }

      /// <summary>Records (or overwrites) the capture of <paramref name="tensor"/> by <paramref name="placeholder"/>.</summary>
      public void replace_capture(Tensor tensor, Tensor placeholder)
      {
          _captures[tensor.Id] = (tensor, placeholder);
      }

      /// <summary>
      /// Converts this graph into a native function, applies <see cref="Attrs"/>, registers
      /// the function with the eager context, and records the function's inputs/outputs.
      /// </summary>
      /// <param name="opers">Operations forming the function body.</param>
      /// <param name="inputs">Tensors that become the function's inputs.</param>
      /// <param name="outputs">Tensors that become the function's outputs.</param>
      /// <param name="output_names">Output names; only used when there is exactly one per output.</param>
      public unsafe void ToGraph(Operation[] opers,
          Tensor[] inputs, Tensor[] outputs,
          string[] output_names)
      {
          var status = new Status();
          output_names ??= Array.Empty<string>();

          // Each TF_Output must name the tensor's own output slot on its op; hard-coding 0
          // picks the wrong tensor for non-first outputs of multi-output ops.
          _func_graph_handle = c_api.TF_GraphToFunction(_handle,
              _graph_key,
              false,
              opers.Length,
              opers.Select(x => (IntPtr)x).ToArray(),
              inputs.Length,
              inputs.Select(x => new TF_Output(x.op, x.value_index)).ToArray(),
              outputs.Length,
              outputs.Select(x => new TF_Output(x.op, x.value_index)).ToArray(),
              output_names.Length != outputs.Length ? null : output_names,
              IntPtr.Zero,
              null,
              status);
          status.Check(true);

          SetAttrs();

          // c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle);
          // status.Check(true);
          c_api.TFE_ContextAddFunction(tf.Context, _func_graph_handle, status);
          status.Check(true);

          // The registered name may differ from the requested key, so read it back.
          _graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle));

          Inputs = inputs;
          // TODO: mark_as_return is not implemented; outputs are used directly.
          Outputs = outputs;
      }

      /// <summary>
      /// Creates an operation in this graph after capturing each input into it.
      /// The caller's <paramref name="inputs"/> array is left unmodified.
      /// </summary>
      public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true)
      {
          var captured_inputs = new Tensor[inputs.Length];
          for (int i = 0; i < inputs.Length; i++)
          {
              captured_inputs[i] = capture(inputs[i]);
          }
          return base.create_op(op_type, captured_inputs, dtypes, input_types, name, attrs, op_def, compute_device);
      }

      const int _EAGER_CONST_THRESHOLD = 128;

      /// <summary>
      /// Returns a tensor usable inside this graph for <paramref name="tensor"/>.
      /// Eager tensors / NDArrays become constants (small value-dtype tensors) or placeholders;
      /// tensors from other graphs become placeholders; tensors already in this graph are
      /// returned unchanged.
      /// </summary>
      /// <exception cref="InaccessibleTensorError">
      /// The tensor belongs to a function graph nested inside this one.
      /// </exception>
      public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
      {
          if (tensor is EagerTensor or NDArray)
          {
              name ??= ops.uid().ToString();
              // Small EagerTensors are captured with Const ops
              if (dtypes.is_value_dtype(tensor.dtype)
                  && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
                  return capture_eager_tensor(tensor, name);
              // Large EagerTensors and resources are captured with Placeholder ops
              return _capture_helper(tensor, name, shape: shape);
          }

          if (tensor.graph != this)
          {
              name ??= tensor.op.name;
              // Walk outward from the tensor's graph; reaching `this` means the tensor lives
              // in a function graph nested inside this one and cannot be captured.
              var inner_graph = tensor.graph;
              while (inner_graph is FuncGraph inner_func_graph)
              {
                  if (inner_graph == this)
                      throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
                          " in another function or code block. Use return values," +
                          " explicit Python locals or TensorFlow collections to access" +
                          $" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
                  inner_graph = inner_func_graph.outer_graph;
              }
              return _capture_helper(tensor, name);
          }

          return tensor;
      }

      /// <summary>
      /// Marks <paramref name="v"/> as accessed while building this graph, unless its handle
      /// was passed in as a resource argument. Each variable is recorded at most once.
      /// </summary>
      public void watch_variable(IVariableV1 v)
      {
          if (_resource_tensor_inputs.Contains(v.Handle))
          {
              return;
          }
          // WeakReference wrappers compare by reference, so the HashSet cannot detect
          // that this variable is already watched; compare targets explicitly.
          foreach (var watched in _watched_variables)
          {
              if (watched.TryGetTarget(out var existing) && ReferenceEquals(existing, v))
              {
                  return;
              }
          }
          _watched_variables.Add(new WeakReference<IVariableV1>(v));
          // NOTE(review): the Python version also registers v with every enclosing FuncGraph.
          // The constructors set outer_graph to the nearest non-function graph, so no walk is done here.
      }

      // Captures a small eager tensor as an in-graph constant (reusing an existing capture).
      Tensor capture_eager_tensor(Tensor tensor, string name)
      {
          Tensor graph_const;
          if (_captures.TryGetValue(tensor.Id, out var existing))
          {
              graph_const = existing.Item2;
          }
          else
          {
              // Build the constant outside any active control-dependency scope.
              graph_const = tf_with(ops.control_dependencies(null), ctl
                  => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
              add_capture(tensor, graph_const);
          }
          _record_capture_gradient(graph_const, tensor);
          return graph_const;
      }

      // Captures a tensor through a substitute placeholder (reusing an existing capture).
      Tensor _capture_helper(Tensor tensor, string name, Shape shape = null)
      {
          Tensor placeholder;
          if (_captures.TryGetValue(tensor.Id, out var existing))
          {
              placeholder = existing.Item2;
          }
          else
          {
              placeholder = _create_substitute_placeholder(tensor,
                  name: name,
                  dtype: tensor.dtype,
                  shape: shape);
              add_capture(tensor, placeholder);
          }
          _record_capture_gradient(placeholder, tensor);
          return placeholder;
      }

      // Records a "captured_value" gradient whose backward function passes gradients through unchanged.
      static void _record_capture_gradient(Tensor captured, Tensor source)
      {
          BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => output_grads;
          tf.Runner.RecordGradient("captured_value",
              new[] { captured }, null,
              new[] { source },
              getBackwardFunction: _backward_function_wrapper
              /*getForwardFunction: forward_function*/);
      }

      // Registers a new capture and appends its substitute to Inputs.
      // Dictionary.Add throws if the tensor is already captured; callers check first.
      void add_capture(Tensor tensor, Tensor placeholder)
      {
          _captures.Add(tensor.Id, (tensor, placeholder));
          Inputs.Add(placeholder);
      }

      // Removes the capture of `tensor` and returns its substitute, or null if not captured.
      // The substitute is not removed from Inputs.
      Tensor pop_capture(Tensor tensor)
      {
          if (!_captures.TryGetValue(tensor.Id, out var entry))
          {
              return null;
          }
          _captures.Remove(tensor.Id);
          return entry.Item2;
      }

      // Creates a placeholder standing in for `value`, outside any control-dependency scope.
      // dtype and shape default to those of `value`.
      Tensor _create_substitute_placeholder(Tensor value,
          string name = null,
          TF_DataType dtype = TF_DataType.DtInvalid,
          Shape shape = null)
      {
          if (shape is null)
              shape = value.shape;
          if (dtype == TF_DataType.DtInvalid)
              dtype = value.dtype;
          var placeholder = tf_with(ops.control_dependencies(null), ctl
              => array_ops.placeholder(dtype, shape: shape, name: name));
          // TODO: custom_gradient.copy_handle_data(value, placeholder)
          return placeholder;
      }

      // Serializes each entry of Attrs onto the native function; no-op when Attrs is null.
      void SetAttrs()
      {
          if (Attrs == null)
              return;
          foreach (var (_name, attr_value) in enumerate(Attrs))
          {
              var serialized = attr_value.ToByteArray();
              c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status);
              tf.Status.Check(true);
          }
      }

      /// <summary>
      /// Switches the context into function graph mode and makes this graph the default.
      /// Pair with <see cref="Exit"/>.
      /// </summary>
      public override Graph as_default()
      {
          tf.Context.graph_mode(isFunc: true);
          ops.set_default_graph(this);
          return this;
      }

      /// <summary>Restores the previous context mode and pops this graph from the default-graph stack.</summary>
      public override void Exit()
      {
          tf.Context.restore_mode();
          ops.pop_graph();
      }

      /// <summary>
      /// Removes the function registered under the current graph key from the eager context.
      /// The resulting status is not checked here.
      /// </summary>
      public void Dispose()
      {
          c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status);
      }

      /// <summary>
      /// Traces <paramref name="func"/> into a <see cref="FuncGraph"/>: builds placeholders for
      /// the arguments, runs the function with this graph as default, converts its outputs to
      /// tensors and records inputs, outputs, captures and variables on the graph.
      /// The default graph and the variable scope's use_resource flag are restored even if
      /// tracing throws.
      /// </summary>
      /// <exception cref="NotImplementedException">Non-empty keyword args, or autograph requested.</exception>
      public static FuncGraph func_graph_from_func(string name, Func<object[], object[]> func,
          object[] args, Dictionary<string, object> kwargs, TensorSpec[] signature = null,
          FuncGraph func_graph = null, bool autograph = false, object autograph_options = null,
          bool add_control_dependencies = true, string[] arg_names = null,
          Tensor op_return_value = null, bool capture_by_value = false,
          bool acd_record_initial_resource_uses = false)
      {
          if (func_graph is null)
          {
              func_graph = new FuncGraph(name);
          }

          // Converts one raw function output to a Tensor; null stays null.
          Tensor convert(object x)
          {
              if (x is null) return null;
              Tensor res = null;
              if (op_return_value is not null && x is Operation)
              {
                  tf_with(ops.control_dependencies(new object[] { x }), _ =>
                  {
                      res = array_ops.identity(op_return_value);
                  });
              }
              else if (x is not TensorArray)
              {
                  Debug.Assert(x is Tensor);
                  res = ops.convert_to_tensor_or_composite(x as Tensor);
              }
              else
              {
                  throw new NotImplementedException($"The `TensorArray` is not supported here currently.");
              }
              if (add_control_dependencies)
              {
                  // TODO(Rinne): `x = deps_ctx.mark_as_return(x)`.
              }
              return res;
          }

          // TODO(Rinne): deal with control dependencies.
          func_graph.as_default();
          try
          {
              var current_scope = variable_scope.get_variable_scope();
              var default_use_resource = current_scope.use_resource;
              current_scope.use_resource = true;

              object[] func_args;
              object[] func_outputs;
              try
              {
                  if (signature is not null)
                  {
                      args = signature;
                      kwargs = new Dictionary<string, object>();
                  }
                  func_args = _get_defun_inputs_from_args(args, arg_names);
                  var func_kwargs = _get_defun_inputs_from_kwargs(kwargs);
                  if (func_kwargs is not null && func_kwargs.Count > 0)
                  {
                      throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`.");
                  }

                  // Remember resource handles passed as arguments so watch_variable skips them.
                  foreach (var arg in nest.flatten<object>(new object[] { func_args, func_kwargs }))
                  {
                      if (arg is Tensor tensor && tensor.dtype == dtypes.resource)
                      {
                          func_graph._resource_tensor_inputs.Add(tensor);
                      }
                      else if (arg is ResourceVariable variable)
                      {
                          func_graph._resource_tensor_inputs.Add(variable.Handle);
                      }
                  }

                  // skip the assignment of `func_graph.structured_input_signature`.
                  var flat_func_args = nest.flatten(func_args as object);
                  var flat_func_kwargs = nest.flatten(func_kwargs as object);
                  func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
                      .Where(x => x is Tensor).Select(x => (Tensor)x).ToArray());

                  if (autograph)
                  {
                      throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported.");
                  }

                  func_outputs = func(func_args);
                  func_outputs = variable_utils.convert_variables_to_tensors(func_outputs);
                  func_outputs = func_outputs.Select(x => convert(x)).ToArray();
                  // TODO(Rinne): `check_func_mutation`.
              }
              finally
              {
                  current_scope.use_resource = default_use_resource;
              }

              var graph_variables = func_graph._watched_variables.ToList();
              HashSet<IVariableV1> arg_variables = new HashSet<IVariableV1>();
              List<Tensor> inputs = new();
              foreach (var arg in composite_tensor_utils.flatten_with_variables(func_args))
              {
                  if (arg is BaseResourceVariable variable)
                  {
                      // Variable arguments were captured when their placeholders were built;
                      // remove that capture and feed its placeholder as a regular input.
                      var resource_placeholder = func_graph.pop_capture(variable.Handle);
                      if (resource_placeholder is null)
                      {
                          continue;
                      }
                      Debug.Assert(variable is IVariableV1);
                      arg_variables.Add(variable as IVariableV1);
                      inputs.Add(resource_placeholder);
                  }
                  else if (arg is Tensor tensor)
                  {
                      inputs.Add(tensor);
                  }
              }

              // Watched variables that are still alive and were not passed in as arguments.
              var variables = graph_variables
                  .Select(v => v.TryGetTarget(out var target) ? target : null)
                  .Where(v => v is not null && !arg_variables.Contains(v));

              func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray();
              func_graph._structured_outputs = func_outputs;
              func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null)
                  .Select(x => func_graph.capture(x)));
              func_graph.Variables = variables;
          }
          finally
          {
              func_graph.Exit();
          }

          if (add_control_dependencies)
          {
              // TODO(Rinne): implement it.
          }
          return func_graph;
      }

      // Builds placeholders (or captures) for positional args, repacked into the structure of `args`.
      private static object[] _get_defun_inputs_from_args(object[] args, string[] names)
      {
          return _get_defun_inputs(args, names, args) as object[];
      }

      // Keyword args are not supported yet; returns `kwargs` unchanged (expected null or empty).
      private static Dictionary<string, object> _get_defun_inputs_from_kwargs(Dictionary<string, object> kwargs)
      {
          // TODO(Rinne): implement it.
          Debug.Assert(kwargs is null || kwargs.Count == 0);
          return kwargs;
          //string[] names;
          //object[] args;
          //if(kwargs is not null && kwargs.Count > 0)
          //{
          //    var sorted_kwargs = kwargs.OrderBy(x => x.Key);
          //    names = sorted_kwargs.Select(x => x.Key).ToArray();
          //    args = sorted_kwargs.Select(x => x.Value).ToArray();
          //}
          //else
          //{
          //    names = new string[0];
          //    args = new object[0];
          //}
          //return _get_defun_inputs(args, names, kwargs) as Dictionary<string, object>;
      }

      // Maps each flattened arg to a function input and packs the results as `structured_args`.
      // When `names` is null every arg is unnamed; zip stops at the shorter of args/names.
      private static object _get_defun_inputs(object[] args, string[] names, object structured_args)
      {
          List<object> function_inputs = new();
          if (names is null)
          {
              names = new string[args.Length];
          }
          foreach (var (arg_value, name) in zip(args, names))
          {
              foreach (var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value))
              {
                  function_inputs.Add(_get_defun_input(val, name));
              }
          }
          return nest.pack_sequence_as(structured_args, nest.flatten<object>(function_inputs), true);
      }

      // Creates the function input for a single arg in the current default FuncGraph:
      // Tensor/TensorSpec -> new placeholder; resource variable -> its handle is captured and the
      // variable itself is returned; anything else is returned unchanged.
      private static object _get_defun_input(object arg, string name)
      {
          var func_graph = ops.get_default_graph() as FuncGraph;
          Debug.Assert(func_graph is not null);
          if (arg is Tensor tensor)
          {
              Tensor placeholder;
              try
              {
                  placeholder = tf.placeholder(tensor.dtype, tensor.shape, name);
              }
              catch (ValueError)
              {
                  // TODO(Rinne): Add warning here.
                  placeholder = tf.placeholder(tensor.dtype, tensor.shape);
              }
              handle_data_util.copy_handle_data(tensor, placeholder);
              if (name is not null)
              {
                  placeholder.op._set_attr("_user_specified_name", new AttrValue()
                  {
                      S = tf.compat.as_bytes(name)
                  });
              }
              return placeholder;
          }
          else if (arg is TensorSpec spec)
          {
              string requested_name = !string.IsNullOrEmpty(spec.name) ? spec.name : name;
              Tensor placeholder;
              try
              {
                  placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name);
              }
              catch (ValueError)
              {
                  // TODO(Rinne): Add warning here.
                  placeholder = tf.placeholder(spec.dtype, spec.shape);
              }
              if (name is not null)
              {
                  placeholder.op._set_attr("_user_specified_name", new AttrValue()
                  {
                      S = tf.compat.as_bytes(requested_name)
                  });
              }
              return placeholder;
          }
          else if (arg is BaseResourceVariable variable)
          {
              // Capture the handle so a placeholder exists while tracing; func_graph_from_func
              // later pops this capture and uses the placeholder as an explicit input.
              var placeholder = func_graph.capture(variable.Handle, name);
              // Names are null when no arg_names were given; only record a real name.
              if (name is not null)
              {
                  placeholder.op._set_attr("_user_specified_name", new AttrValue()
                  {
                      S = tf.compat.as_bytes(name)
                  });
              }
              return arg;
          }
          // TODO(Rinne): deal with `VariableSpec`.
          else
          {
              return arg;
          }
      }
  }