You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FunctionGenerator.cs 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Linq;
  5. using System.Linq.Expressions;
  6. using System.Reflection.Metadata.Ecma335;
  7. using System.Text;
  8. using System.Threading.Tasks;
  9. using Microsoft.CodeAnalysis.CSharp;
  10. namespace Tensorflow.CodeGen
  11. {
  12. public class FunctionGenerator
  13. {
  /// <summary>
  /// Appends to <paramref name="sb"/> the C# source of the public static wrapper for
  /// <paramref name="op"/>, followed by the definition of its <c>_eager_fallback</c> companion.
  /// </summary>
  /// <remarks>
  /// The emitted wrapper returns <c>Operation</c> when the op has no outputs, <c>Tensor</c> when
  /// it has exactly one output that carries neither a number attr nor a type-list attr, and
  /// <c>Tensor[]</c> otherwise. Its body has an eager branch (fast path, then eager fallback)
  /// followed by the <c>_apply_op_helper</c> path with optional gradient recording.
  /// </remarks>
  /// <param name="op">Definition of the op to generate a wrapper for.</param>
  /// <param name="sb">Builder that receives the generated source.</param>
  public void AppendFunction(OpDef op, StringBuilder sb)
  {
      // Identifiers that collide with C# keywords get a trailing underscore, which is
      // how the parameter list emitted by AppendArgs declares them.
      static string Escape(string identifier) =>
          SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;

      int outputArgsCount = op.OutputArg.Count;
      bool returnsSingleTensor = outputArgsCount == 1
          && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
          && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr);

      // TODO: add descriptions
      sb.Append("public static ");
      if (outputArgsCount == 0)
      {
          sb.Append("Operation ");
      }
      else if (returnsSingleTensor)
      {
          sb.Append("Tensor ");
      }
      else
      {
          sb.Append("Tensor[] ");
      }

      // A function name that is itself a C# keyword is prefixed with an underscore.
      string funcName = Utils.ConvertToUnderscore(op.Name);
      if (SyntaxFactory.ParseToken(funcName).IsKeyword())
      {
          funcName = $"_{funcName}";
      }
      sb.Append($" {funcName}(");

      // define args
      AppendArgs(op, sb);
      sb.Append(")\n{\n");

      // begin to write main body
      sb.AppendLine("var _ctx = tf.Context;");

      var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues);

      // Dynamic defaults are assigned inside the body when the argument is null.
      // NOTE(review): assumes each name here matches a parameter declared by AppendArgs.
      foreach (var (name, expr) in dynamicDefaultValues)
      {
          string variableName = Escape(name);
          sb.AppendLine($"if({variableName} is null)");
          sb.AppendLine("{");
          sb.AppendLine($"{variableName} = {expr};");
          sb.AppendLine("}");
      }

      sb.AppendLine("if(_ctx.executing_eagerly()){");
      if (HasRefArgs(op))
      {
          var refArg = op.InputArg.First(x => x.IsRef);
          sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {refArg.Name} is a ref.\");");
      }
      else
      {
          sb.Append("try\n{\n");
          AppendFastPathExecute(op, sb);
          if (outputArgsCount == 0)
          {
              sb.AppendLine("return null;");
          }
          else if (returnsSingleTensor)
          {
              sb.AppendLine("return _fast_path_result[0];");
          }
          else
          {
              sb.AppendLine("return _fast_path_result;");
          }
          sb.AppendLine("}"); // try

          // These two exception types propagate to the caller; a bare `throw;` is emitted
          // so the generated code keeps the original stack trace.
          sb.Append("catch(NotOkStatusException)\n{\n");
          sb.AppendLine("throw;");
          sb.AppendLine("}"); // catch
          sb.Append("catch(InvalidArgumentError)\n{\n");
          sb.AppendLine("throw;");
          sb.AppendLine("}"); // catch

          // Any other fast-path failure is swallowed so the eager fallback can be tried.
          sb.Append("catch(Exception)\n{\n");
          sb.AppendLine("}"); // catch

          sb.Append("try\n{\n");
          AppendEagerFallbackCall(op, sb);
          sb.AppendLine("}"); // try
          // A failing fallback is swallowed too; execution continues with the op-helper path.
          sb.Append("catch(Exception)\n{\n");
          sb.AppendLine("}"); // catch
      }
      sb.AppendLine("}"); // if

      // String attributes that have a default get it assigned when the argument is null.
      foreach (var (name, _, value) in attrValueDic.Where(x => x.Item2 == "string"))
      {
          if (value != "NOVALUE")
          {
              string variableName = Escape(name);
              sb.AppendLine($"if({variableName} is null)");
              sb.AppendLine("{");
              sb.AppendLine($"{variableName} = {value};");
              sb.AppendLine("}");
          }
      }

      // begin to use op helper.
      AppendOpHelperCall(op, sb);
      sb.AppendLine("var _result = _op.outputs;");

      // check if it needs to record gradient.
      sb.Append("if(_execute.must_record_gradient())\n{\n");
      sb.Append("object[] _attrs = new object[]{");
      var attrEntries = op.Attr.Select(attr =>
      {
          string getter = attr.Type switch
          {
              "type" => "_get_attr_type",
              "int" => "_get_attr_int",
              "bool" => "_get_attr_bool",
              _ => "get_attr",
          };
          // The lookup key is the op's own attribute name, so it is never keyword-escaped.
          return $"\"{attr.Name}\", _op.{getter}(\"{attr.Name}\")";
      });
      sb.Append(string.Join(", ", attrEntries));
      sb.Append("};\n");
      sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);");
      sb.AppendLine("}"); // if

      if (outputArgsCount == 0)
      {
          sb.AppendLine("return _op;");
      }
      else if (returnsSingleTensor)
      {
          sb.AppendLine("return _result[0];");
      }
      else
      {
          sb.AppendLine("return _result;");
      }
      sb.AppendLine("}"); // body
      sb.AppendLine();

      AppendEagerFallbackDefinition(op, sb);
  }
  /// <summary>
  /// Appends the parameter list of the public wrapper for <paramref name="op"/>:
  /// input tensors, then attributes without a default, then attributes with a default,
  /// and finally <c>string? name = null</c>.
  /// </summary>
  /// <param name="op">Definition of the op whose parameters are written.</param>
  /// <param name="sb">Builder that receives the parameter list (without parentheses).</param>
  public void AppendArgs(OpDef op, StringBuilder sb)
  {
      // A name that parses as a C# keyword gets a trailing underscore.
      static string Escape(string id) => SyntaxFactory.ParseToken(id).IsKeyword() ? $"{id}_" : id;

      foreach (var input in op.InputArg)
      {
          // Inputs tied to a number attr or a type-list attr are declared as Tensors.
          bool isSingle = string.IsNullOrEmpty(input.NumberAttr) && string.IsNullOrEmpty(input.TypeListAttr);
          string tensorType = isSingle ? "Tensor" : "Tensors";
          sb.Append($"{tensorType} {Escape(input.Name)}, ");
      }

      var attrs = Utils.GetAttrsDefaultValue(op, out _);

      // Attributes marked "NOVALUE" have no default and are written first.
      foreach (var (attrName, attrType, _) in attrs.Where(a => a.Item3 == "NOVALUE"))
      {
          sb.Append($"{attrType} {Escape(attrName)}, ");
      }

      // The remaining attributes are written with their default value.
      foreach (var (attrName, attrType, defaultValue) in attrs.Where(a => a.Item3 != "NOVALUE"))
      {
          sb.Append($"{attrType} {Escape(attrName)} = {defaultValue}, ");
      }

      sb.Append("string? name = null");
  }
  /// <summary>
  /// Appends the statement that assigns <c>_fast_path_result</c> from
  /// <c>tf.Runner.TFE_FastPathExecute</c>, passing the op's inputs as <c>args</c> and its
  /// attributes as the <c>attrs</c> dictionary.
  /// </summary>
  /// <param name="op">Definition of the op being executed.</param>
  /// <param name="sb">Builder that receives the generated statement.</param>
  public void AppendFastPathExecute(OpDef op, StringBuilder sb)
  {
      // Variable references must use the keyword-escaped name, matching how the
      // wrapper's parameters are declared by AppendArgs.
      static string Escape(string identifier) =>
          SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;

      sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)");
      sb.Append("{ args = new object[]{ ");
      sb.Append(string.Join(", ", op.InputArg.Select(arg => Escape(arg.Name))));
      sb.Append("}, attrs = new Dictionary<string, object>(){ ");

      var attrValueDic = Utils.GetAttrsDefaultValue(op, out _);
      // The dictionary key stays the raw attribute name; only the value expression is escaped.
      sb.Append(string.Join(", ", attrValueDic.Select(x => $"[\"{x.Item1}\"] = {Escape(x.Item1)}")));
      sb.Append("}});\n");
  }
  /// <summary>
  /// Appends a <c>return</c> statement that calls the op's <c>_eager_fallback</c> function:
  /// inputs are passed positionally, attributes and <c>name</c>/<c>ctx</c> as named arguments.
  /// </summary>
  /// <param name="op">Definition of the op whose fallback is called.</param>
  /// <param name="sb">Builder that receives the generated statement.</param>
  public void AppendEagerFallbackCall(OpDef op, StringBuilder sb)
  {
      static string Escape(string identifier) =>
          SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;

      string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback";
      sb.Append($"return {funcName}(");

      foreach (var arg in op.InputArg)
      {
          sb.Append($"{Escape(arg.Name)}, ");
      }

      var attrValueDic = Utils.GetAttrsDefaultValue(op, out _);
      foreach (var (key, _, _) in attrValueDic)
      {
          // The named-argument label must match the fallback's parameter name, which
          // AppendFallBackFunctionArgs declares in keyword-escaped form.
          string parameterName = Escape(key);
          sb.Append($"{parameterName}: {parameterName}, ");
      }

      sb.Append("name: name, ctx: _ctx);\n");
  }
  /// <summary>
  /// Appends the full definition of the <c>_eager_fallback</c> function for
  /// <paramref name="op"/>: signature, flattened inputs, the <c>_attrs</c> array, the
  /// <c>_execute.execute</c> call, optional gradient recording and the return statement.
  /// </summary>
  /// <remarks>
  /// When any input argument is a ref, the emitted body only throws a <c>RuntimeError</c>.
  /// </remarks>
  /// <param name="op">Definition of the op to generate the fallback for.</param>
  /// <param name="sb">Builder that receives the generated source.</param>
  public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb)
  {
      // Variable references use the keyword-escaped name, matching the parameter
      // declarations written by AppendFallBackFunctionArgs.
      static string Escape(string identifier) =>
          SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;

      int outputArgsCount = op.OutputArg.Count;
      bool returnsSingleTensor = outputArgsCount == 1
          && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
          && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr);

      sb.Append("public static ");
      if (outputArgsCount == 0)
      {
          sb.Append("Operation ");
      }
      else if (returnsSingleTensor)
      {
          sb.Append("Tensor ");
      }
      else
      {
          sb.Append("Tensor[] ");
      }

      string funcName = Utils.ConvertToUnderscore(op.Name);
      sb.Append($" {funcName}_eager_fallback(");
      AppendFallBackFunctionArgs(op, sb);
      sb.Append(")\n{\n");

      var refArg = op.InputArg.FirstOrDefault(x => x.IsRef);
      if (refArg is not null)
      {
          sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." +
              $" Arg '{refArg.Name}' is a ref.\");");
          sb.AppendLine("}"); // body
          return;
      }

      if (op.InputArg.Any(x => !string.IsNullOrEmpty(x.NumberAttr)))
      {
          // At least one input is a list, so the inputs are flattened through a List<Tensor>.
          sb.AppendLine("List<Tensor> _inputs_flat_list = new();");
          foreach (var arg in op.InputArg)
          {
              string addMethod = string.IsNullOrEmpty(arg.NumberAttr) ? "Add" : "AddRange";
              sb.AppendLine($"_inputs_flat_list.{addMethod}({Escape(arg.Name)});");
          }
          sb.AppendLine("var _inputs_flat = _inputs_flat_list.ToArray();");
      }
      else
      {
          sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
          sb.Append(string.Join(", ", op.InputArg.Select(arg => Escape(arg.Name))));
          sb.Append("};\n");
      }

      // Each entry is a `"name", value` pair; the entries are joined with ", ".
      var attrEntries = new List<string>();
      foreach (var attr in op.Attr)
      {
          if (attr.Type == "type")
          {
              // Use the dtype of the first input bound to this type attr, if there is one;
              // otherwise the attribute is an explicit parameter.
              var typedArg = op.InputArg.FirstOrDefault(x => x.TypeAttr == attr.Name);
              string valueExpr = typedArg is not null
                  ? $"{Escape(typedArg.Name)}.dtype"
                  : Escape(attr.Name);
              attrEntries.Add($"\"{attr.Name}\", {valueExpr}");
          }
          else if (attr.Type == "list(type)")
          {
              // No entry is emitted for list(type) attributes, whether or not an input
              // argument refers to them through its TypeListAttr.
              // NOTE(review): confirm that skipping the non-inferred case is intended.
              continue;
          }
          else if (attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name))
          {
              // The value is the length of the first list input counted by this attr.
              var countedArg = op.InputArg.First(x => x.NumberAttr == attr.Name);
              attrEntries.Add($"\"{attr.Name}\", {Escape(countedArg.Name)}.Length");
          }
          else
          {
              attrEntries.Add($"\"{attr.Name}\", {Escape(attr.Name)}");
          }
      }
      sb.Append("object[] _attrs = new object[]{");
      sb.Append(string.Join(", ", attrEntries));
      sb.Append("};\n");

      sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, " +
          $"attrs: _attrs, ctx: ctx, name: name);");
      sb.Append("if(_execute.must_record_gradient())\n{\n");
      sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);");
      sb.AppendLine("}"); // if

      if (outputArgsCount == 0)
      {
          sb.AppendLine("return null;");
      }
      else if (returnsSingleTensor)
      {
          sb.AppendLine("return _result[0];");
      }
      else
      {
          sb.AppendLine("return _result;");
      }
      sb.AppendLine("}"); // body
  }
  /// <summary>
  /// Appends the parameter list of the <c>_eager_fallback</c> function for
  /// <paramref name="op"/>: input tensors, every attribute (without defaults), and
  /// finally <c>string name, Context ctx</c>.
  /// </summary>
  /// <param name="op">Definition of the op whose fallback parameters are written.</param>
  /// <param name="sb">Builder that receives the parameter list (without parentheses).</param>
  public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb)
  {
      // A name that parses as a C# keyword gets a trailing underscore.
      static string Escape(string id) => SyntaxFactory.ParseToken(id).IsKeyword() ? $"{id}_" : id;

      foreach (var input in op.InputArg)
      {
          // Only inputs tied to a number attr are declared as Tensors here.
          string tensorType = string.IsNullOrEmpty(input.NumberAttr) ? "Tensor" : "Tensors";
          sb.Append($"{tensorType} {Escape(input.Name)}, ");
      }

      var attrs = Utils.GetAttrsDefaultValue(op, out _);
      foreach (var (attrName, attrType, _) in attrs)
      {
          sb.Append($"{attrType} {Escape(attrName)}, ");
      }

      sb.Append("string name, Context ctx");
  }
  /// <summary>
  /// Appends statements that fill a <c>keywords</c> dictionary with the op's inputs and
  /// attributes, then assign <c>_op</c> from <c>tf.OpDefLib._apply_op_helper</c>.
  /// </summary>
  /// <param name="op">Definition of the op being created.</param>
  /// <param name="sb">Builder that receives the generated statements.</param>
  public void AppendOpHelperCall(OpDef op, StringBuilder sb)
  {
      // Dictionary keys keep the op's own names; the value expressions refer to the
      // wrapper's parameters and therefore use the keyword-escaped form.
      static string Escape(string identifier) =>
          SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;

      sb.AppendLine("Dictionary<string, object> keywords = new();");
      foreach (var arg in op.InputArg)
      {
          sb.AppendLine($"keywords[\"{arg.Name}\"] = {Escape(arg.Name)};");
      }

      var attrValueDic = Utils.GetAttrsDefaultValue(op, out _);
      foreach (var (key, _, _) in attrValueDic)
      {
          sb.AppendLine($"keywords[\"{key}\"] = {Escape(key)};");
      }

      sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
  }
  /// <summary>
  /// Returns <c>true</c> when at least one input argument of <paramref name="op"/> has
  /// <c>IsRef</c> set.
  /// </summary>
  private static bool HasRefArgs(OpDef op) => op.InputArg.Any(input => input.IsRef);
  461. }
  462. }