You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FunctionGenerator.cs 18 kB


  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Linq;
  5. using System.Linq.Expressions;
  6. using System.Reflection.Metadata.Ecma335;
  7. using System.Text;
  8. using System.Threading.Tasks;
  9. using Microsoft.CodeAnalysis.CSharp;
  10. namespace Tensorflow.CodeGen
  11. {
  /// <summary>
  /// Generates C# wrapper source for a TensorFlow op described by an <c>OpDef</c>.
  /// </summary>
  /// <remarks>
  /// <see cref="AppendFunction"/> writes a public entry function that tries the eager fast
  /// path, then the eager fallback, and finally builds a graph node via
  /// <c>_apply_op_helper</c>; it is followed by a matching <c>*_eager_fallback</c> function.
  /// Input/attr names that are C# keywords get a trailing underscore (e.g. <c>params</c>
  /// becomes <c>params_</c>) wherever they are used as C# identifiers, while the original
  /// name is kept wherever it is used as a TensorFlow attr/input key string.
  /// </remarks>
  public class FunctionGenerator
  {
      /// <summary>
      /// Appends the public wrapper for <paramref name="op"/> followed by its eager fallback.
      /// </summary>
      /// <param name="op">Op definition to generate code for.</param>
      /// <param name="sb">Builder that receives the generated C# source.</param>
      public void AppendFunction(OpDef op, StringBuilder sb)
      {
          // TODO: add descriptions
          string funcName = Utils.ConvertToUnderscore(op.Name);
          if (IsKeyword(funcName))
          {
              // The op name becomes a method name, so it is escaped with a leading underscore.
              funcName = $"_{funcName}";
          }

          sb.Append("public static ");
          AppendReturnType(op, sb);
          sb.Append($"{funcName}(");
          // define args
          AppendArgs(op, sb);
          sb.Append(")\n{\n");

          // begin to write main body
          sb.AppendLine("var _ctx = tf.Context;");
          var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues);
          // Defaults that cannot be C# compile-time constants are assigned at runtime.
          foreach (var (name, expr) in dynamicDefaultValues)
          {
              AppendNullDefault(sb, EscapeKeyword(name), expr);
          }

          sb.AppendLine("if(_ctx.executing_eagerly()){");
          if (HasRefArgs(op))
          {
              var refArg = op.InputArg.First(x => x.IsRef);
              sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {refArg.Name} is a ref.\");");
          }
          else
          {
              // NOTE(review): both generated catch blocks are empty, so a fast-path failure falls
              // through to the eager fallback and a fallback failure falls through to graph-mode
              // construction below. Confirm that swallowing these errors is intended.
              sb.Append("try\n{\n");
              AppendFastPathExecute(op, sb);
              AppendReturnStatement(op, sb, "null", "_fast_path_result[0]", "_fast_path_result");
              sb.AppendLine("}"); // try
              sb.Append("catch(Exception)\n{\n");
              sb.AppendLine("}"); // catch
              sb.Append("try\n{\n");
              AppendEagerFallbackCall(op, sb);
              sb.AppendLine("}"); // try
              sb.Append("catch(Exception)\n{\n");
              sb.AppendLine("}"); // catch
          }
          sb.AppendLine("}"); // if

          foreach (var (name, _, value) in attrValueDic.Where(x => x.Item2 == "string" && x.Item3 != "NOVALUE"))
          {
              AppendNullDefault(sb, EscapeKeyword(name), value);
          }

          // begin to use op helper.
          AppendOpHelperCall(op, sb);
          sb.AppendLine("var _result = _op.outputs;");
          // check if it needs to record gradient.
          sb.Append("if(_execute.must_record_gradient())\n{\n");
          // The getters look attributes up on the created op, so they take the op's own attr
          // name; the keyword-escaped form is only a C# identifier and must not be used here.
          var graphAttrs = op.Attr.Select(attr => $"\"{attr.Name}\", {GraphAttrGetter(attr.Name, attr.Type)}");
          sb.Append("object[] _attrs = new object[]{");
          sb.Append(string.Join(", ", graphAttrs));
          sb.Append("};\n");
          sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);");
          sb.AppendLine("}"); // if
          AppendReturnStatement(op, sb, "_op", "_result[0]", "_result");
          sb.AppendLine("}"); // body
          sb.AppendLine();
          AppendEagerFallbackDefinition(op, sb);
      }

      /// <summary>
      /// Appends the parameter list of the public wrapper: inputs, required attrs, attrs with
      /// defaults, and a trailing optional <c>name</c>.
      /// </summary>
      /// <param name="op">Op definition whose inputs/attrs become parameters.</param>
      /// <param name="sb">Builder that receives the generated parameter list.</param>
      public void AppendArgs(OpDef op, StringBuilder sb)
      {
          foreach (var arg in op.InputArg)
          {
              string argType = IsListArg(arg.NumberAttr, arg.TypeListAttr) ? "Tensors" : "Tensor";
              sb.Append($"{argType} {EscapeKeyword(arg.Name)}, ");
          }
          var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
          // Required attrs must precede optional ones in a C# parameter list.
          foreach (var (key, typeStr, _) in attrValueDic.Where(x => x.Item3 == "NOVALUE"))
          {
              sb.Append($"{typeStr} {EscapeKeyword(key)}, ");
          }
          foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 != "NOVALUE"))
          {
              sb.Append($"{typeStr} {EscapeKeyword(key)} = {value}, ");
          }
          sb.Append("string? name = null");
      }

      /// <summary>
      /// Appends the statement that runs the op through <c>tf.Runner.TFE_FastPathExecute</c>
      /// and stores its outputs in <c>_fast_path_result</c>.
      /// </summary>
      /// <param name="op">Op definition to execute.</param>
      /// <param name="sb">Builder that receives the generated statement.</param>
      public void AppendFastPathExecute(OpDef op, StringBuilder sb)
      {
          var inputNames = op.InputArg.Select(arg => EscapeKeyword(arg.Name));
          var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
          // Dictionary key is the TF attr name; the value is the (escaped) C# parameter.
          var attrEntries = attrValueDic.Select(x => $"[\"{x.Item1}\"] = {EscapeKeyword(x.Item1)}");

          sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)");
          sb.Append("{ args = new object[]{ ");
          sb.Append(string.Join(", ", inputNames));
          sb.Append("}, attrs = new Dictionary<string, object>(){ ");
          sb.Append(string.Join(", ", attrEntries));
          sb.Append("}});\n");
      }

      /// <summary>
      /// Appends a <c>return</c> statement that forwards every input and attr to the
      /// generated <c>*_eager_fallback</c> function.
      /// </summary>
      /// <param name="op">Op definition whose fallback is called.</param>
      /// <param name="sb">Builder that receives the generated statement.</param>
      public void AppendEagerFallbackCall(OpDef op, StringBuilder sb)
      {
          string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback";
          sb.Append($"return {funcName}(");
          foreach (var arg in op.InputArg)
          {
              sb.Append($"{EscapeKeyword(arg.Name)}, ");
          }
          var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
          foreach (var (key, _, _) in attrValueDic)
          {
              // The named argument must match the escaped parameter name declared by
              // AppendFallBackFunctionArgs (a raw keyword is not even a valid label).
              string realKey = EscapeKeyword(key);
              sb.Append($"{realKey}: {realKey}, ");
          }
          sb.Append("name: name, ctx: _ctx);\n");
      }

      /// <summary>
      /// Appends the complete <c>*_eager_fallback</c> function, which flattens the inputs,
      /// builds the attr list, executes the op via <c>_execute.execute</c> and records
      /// gradients when required.
      /// </summary>
      /// <param name="op">Op definition to generate the fallback for.</param>
      /// <param name="sb">Builder that receives the generated function.</param>
      public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb)
      {
          string funcName = Utils.ConvertToUnderscore(op.Name);
          sb.Append("public static ");
          AppendReturnType(op, sb);
          sb.Append($"{funcName}_eager_fallback(");
          AppendFallBackFunctionArgs(op, sb);
          sb.Append(")\n{\n");

          var refArg = op.InputArg.FirstOrDefault(x => x.IsRef);
          if (refArg is not null)
          {
              sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." +
                  $" Arg '{refArg.Name}' is a ref.\");");
              sb.AppendLine("}"); // body
              return;
          }

          AppendInputsFlat(op, sb);
          AppendFallbackAttrs(op, sb);
          sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {GetOutputCountExpression(op)}, inputs: _inputs_flat, " +
              $"attrs: _attrs, ctx: ctx, name: name);");
          sb.Append("if(_execute.must_record_gradient())\n{\n");
          sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);");
          sb.AppendLine("}"); // if
          AppendReturnStatement(op, sb, "null", "_result[0]", "_result");
          sb.AppendLine("}"); // body
      }

      /// <summary>
      /// Appends the parameter list of the eager fallback: inputs, every attr (no defaults),
      /// then the mandatory <c>name</c> and <c>ctx</c>.
      /// </summary>
      /// <param name="op">Op definition whose inputs/attrs become parameters.</param>
      /// <param name="sb">Builder that receives the generated parameter list.</param>
      public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb)
      {
          foreach (var arg in op.InputArg)
          {
              // Must agree with AppendArgs: the public wrapper passes list inputs (number_attr
              // or type_list_attr) as Tensors, so the fallback has to accept Tensors too.
              string argType = IsListArg(arg.NumberAttr, arg.TypeListAttr) ? "Tensors" : "Tensor";
              sb.Append($"{argType} {EscapeKeyword(arg.Name)}, ");
          }
          var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
          foreach (var (key, typeStr, _) in attrValueDic)
          {
              sb.Append($"{typeStr} {EscapeKeyword(key)}, ");
          }
          sb.Append("string name, Context ctx");
      }

      /// <summary>
      /// Appends the statements that collect inputs and attrs into a <c>keywords</c> dictionary
      /// and create the graph op via <c>tf.OpDefLib._apply_op_helper</c> into <c>_op</c>.
      /// </summary>
      /// <param name="op">Op definition to build.</param>
      /// <param name="sb">Builder that receives the generated statements.</param>
      public void AppendOpHelperCall(OpDef op, StringBuilder sb)
      {
          sb.AppendLine("Dictionary<string, object> keywords = new();");
          foreach (var arg in op.InputArg)
          {
              sb.AppendLine($"keywords[\"{arg.Name}\"] = {EscapeKeyword(arg.Name)};");
          }
          var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
          foreach (var (key, _, _) in attrValueDic)
          {
              sb.AppendLine($"keywords[\"{key}\"] = {EscapeKeyword(key)};");
          }
          sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
      }

      /// <summary>True when any input of <paramref name="op"/> is a ref input.</summary>
      private static bool HasRefArgs(OpDef op)
      {
          return op.InputArg.Any(x => x.IsRef);
      }

      /// <summary>True when <paramref name="identifier"/> parses as a reserved C# keyword.</summary>
      private static bool IsKeyword(string identifier) =>
          SyntaxFactory.ParseToken(identifier).IsKeyword();

      /// <summary>Returns a usable C# identifier for an input/attr name (keyword + "_").</summary>
      private static string EscapeKeyword(string identifier) =>
          IsKeyword(identifier) ? $"{identifier}_" : identifier;

      /// <summary>True when an arg with these attrs carries a list of tensors.</summary>
      private static bool IsListArg(string numberAttr, string typeListAttr) =>
          !string.IsNullOrEmpty(numberAttr) || !string.IsNullOrEmpty(typeListAttr);

      /// <summary>True when the op has exactly one output and it is a single tensor.</summary>
      private static bool IsSingleTensorOutput(OpDef op) =>
          op.OutputArg.Count == 1 && !IsListArg(op.OutputArg[0].NumberAttr, op.OutputArg[0].TypeListAttr);

      /// <summary>Appends "Operation ", "Tensor " or "Tensor[] " according to the op's outputs.</summary>
      private static void AppendReturnType(OpDef op, StringBuilder sb)
      {
          if (op.OutputArg.Count == 0)
          {
              sb.Append("Operation ");
          }
          else if (IsSingleTensorOutput(op))
          {
              sb.Append("Tensor ");
          }
          else
          {
              sb.Append("Tensor[] ");
          }
      }

      /// <summary>Appends a return statement whose expression matches the op's output shape.</summary>
      private static void AppendReturnStatement(OpDef op, StringBuilder sb,
          string noOutputExpr, string singleOutputExpr, string multiOutputExpr)
      {
          string expr = op.OutputArg.Count == 0
              ? noOutputExpr
              : IsSingleTensorOutput(op) ? singleOutputExpr : multiOutputExpr;
          sb.AppendLine($"return {expr};");
      }

      /// <summary>Appends <c>if(x is null) { x = value; }</c> for a generated parameter.</summary>
      private static void AppendNullDefault(StringBuilder sb, string variable, object valueExpr)
      {
          sb.AppendLine($"if({variable} is null)");
          sb.AppendLine("{");
          sb.AppendLine($"{variable} = {valueExpr};");
          sb.AppendLine("}");
      }

      /// <summary>Returns the graph-mode expression that reads attr <paramref name="attrName"/> from <c>_op</c>.</summary>
      private static string GraphAttrGetter(string attrName, string attrType) => attrType switch
      {
          "type" => $"_op._get_attr_type(\"{attrName}\")",
          "int" => $"_op._get_attr_int(\"{attrName}\")",
          "bool" => $"_op._get_attr_bool(\"{attrName}\")",
          _ => $"_op.get_attr(\"{attrName}\")",
      };

      /// <summary>Appends the declaration of <c>_inputs_flat</c> (a <c>Tensor[]</c>) for the fallback.</summary>
      private static void AppendInputsFlat(OpDef op, StringBuilder sb)
      {
          if (op.InputArg.Any(x => IsListArg(x.NumberAttr, x.TypeListAttr)))
          {
              // List inputs (number_attr or type_list_attr) are Tensors and are spliced in element-wise.
              sb.AppendLine("List<Tensor> _inputs_flat_list = new();");
              foreach (var arg in op.InputArg)
              {
                  string method = IsListArg(arg.NumberAttr, arg.TypeListAttr) ? "AddRange" : "Add";
                  sb.AppendLine($"_inputs_flat_list.{method}({EscapeKeyword(arg.Name)});");
              }
              sb.AppendLine("var _inputs_flat = _inputs_flat_list.ToArray();");
          }
          else
          {
              string joined = string.Join(", ", op.InputArg.Select(arg => EscapeKeyword(arg.Name)));
              sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
              sb.Append(joined);
              sb.Append("};\n");
          }
      }

      /// <summary>Appends the declaration of the fallback's <c>_attrs</c> name/value array.</summary>
      private static void AppendFallbackAttrs(OpDef op, StringBuilder sb)
      {
          var entries = new List<string>();
          foreach (var attr in op.Attr)
          {
              string attrRealName = EscapeKeyword(attr.Name);
              if (attr.Type == "type")
              {
                  // Prefer the dtype of the first input typed by this attr; otherwise it is a parameter.
                  var typedArg = op.InputArg.FirstOrDefault(x => x.TypeAttr == attr.Name);
                  string value = typedArg is null ? attrRealName : $"{EscapeKeyword(typedArg.Name)}.dtype";
                  entries.Add($"\"{attr.Name}\", {value}");
              }
              else if (attr.Type == "list(type)")
              {
                  // NOTE(review): list(type) attrs inferred from an input are not passed here
                  // (TensorFlow's Python eager fallbacks do pass them) — confirm this is intended.
                  if (!op.InputArg.Any(x => x.TypeListAttr == attr.Name))
                  {
                      // Not inferable from inputs, so the value comes from the attr parameter;
                      // previously such attrs were silently dropped.
                      entries.Add($"\"{attr.Name}\", {attrRealName}");
                  }
              }
              else if (attr.Type == "int" && op.InputArg.FirstOrDefault(x => x.NumberAttr == attr.Name) is { } sizedArg)
              {
                  // Length attr of a list input is taken from that input.
                  entries.Add($"\"{attr.Name}\", {EscapeKeyword(sizedArg.Name)}.Length");
              }
              else
              {
                  entries.Add($"\"{attr.Name}\", {attrRealName}");
              }
          }
          sb.Append("object[] _attrs = new object[]{");
          sb.Append(string.Join(", ", entries));
          sb.Append("};\n");
      }

      /// <summary>
      /// Returns a C# expression for the total number of output tensors. Each single output
      /// counts 1; a number_attr output contributes N and a type_list_attr output contributes
      /// the list length, taken from a matching input when one exists, else from the attr parameter.
      /// </summary>
      private static string GetOutputCountExpression(OpDef op)
      {
          int fixedCount = 0;
          var listLengths = new List<string>();
          foreach (var output in op.OutputArg)
          {
              if (!string.IsNullOrEmpty(output.NumberAttr))
              {
                  var sizedInput = op.InputArg.FirstOrDefault(x => x.NumberAttr == output.NumberAttr);
                  listLengths.Add(sizedInput is null
                      ? EscapeKeyword(output.NumberAttr)
                      : $"{EscapeKeyword(sizedInput.Name)}.Length");
              }
              else if (!string.IsNullOrEmpty(output.TypeListAttr))
              {
                  var typedInput = op.InputArg.FirstOrDefault(x => x.TypeListAttr == output.TypeListAttr);
                  listLengths.Add(typedInput is null
                      ? $"{EscapeKeyword(output.TypeListAttr)}.Length"
                      : $"{EscapeKeyword(typedInput.Name)}.Length");
              }
              else
              {
                  fixedCount++;
              }
          }
          if (listLengths.Count == 0)
          {
              return $"{fixedCount}";
          }
          if (fixedCount > 0)
          {
              listLengths.Insert(0, $"{fixedCount}");
          }
          return string.Join(" + ", listLengths);
      }
  }
  456. }