You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FunctionGenerator.cs 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Linq;
  5. using System.Reflection.Metadata.Ecma335;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Microsoft.CodeAnalysis.CSharp;
  9. namespace Tensorflow.CodeGen
  10. {
  11. public class FunctionGenerator
  12. {
  /// <summary>
  /// Appends the generated C# source of the public static wrapper for <paramref name="op"/>
  /// to <paramref name="sb"/>, followed by the matching <c>*_eager_fallback</c> definition.
  /// </summary>
  /// <remarks>
  /// The emitted wrapper returns <c>Tensor[]</c> for more than one output arg, <c>Tensor</c> for
  /// exactly one and <c>Operation</c> for none. Its body first emits an eager-execution branch
  /// (a throw for ops with ref inputs, otherwise a fast-path attempt followed by the eager
  /// fallback call), then the op-helper path with optional gradient recording.
  /// </remarks>
  /// <param name="op">Op definition the wrapper is generated from.</param>
  /// <param name="sb">Builder that receives the generated source text.</param>
  public void AppendFunction(OpDef op, StringBuilder sb)
  {
      // TODO: add descriptions
      sb.Append("public static ");
      int outputArgsCount = op.OutputArg.Count;
      if (outputArgsCount > 1)
      {
          sb.Append("Tensor[] ");
      }
      else if (outputArgsCount == 1)
      {
          sb.Append("Tensor ");
      }
      else
      {
          sb.Append("Operation ");
      }

      // A function name that collides with a C# keyword is prefixed with an underscore.
      string funcName = Utils.ConvertToUnderscore(op.Name);
      if (SyntaxFactory.ParseToken(funcName).IsKeyword())
      {
          funcName = $"_{funcName}";
      }
      sb.Append($" {funcName}(");

      // define args
      AppendArgs(op, sb);
      sb.Append(")\n{\n");

      // begin to write main body
      sb.AppendLine("var _ctx = tf.Context;");
      sb.AppendLine("if(_ctx.executing_eagerly()){");
      if (HasRefArgs(op))
      {
          var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null);
          sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {possibleRefArg.Name} is a ref.\");");
      }
      else
      {
          sb.Append("try\n{\n");
          AppendFastPathExecute(op, sb);
          if (outputArgsCount == 0)
          {
              sb.AppendLine("return null;");
          }
          else if (outputArgsCount == 1)
          {
              sb.AppendLine("return _fast_path_result[0];");
          }
          else
          {
              sb.AppendLine("return _fast_path_result;");
          }
          sb.AppendLine("}"); // try
          sb.Append("catch(Exception)\n{\n");
          sb.AppendLine("}"); // catch
          sb.Append("try\n{\n");
          AppendEagerFallbackCall(op, sb);
          sb.AppendLine("}"); // try
          sb.Append("catch(Exception)\n{\n");
          sb.AppendLine("}"); // catch
      }
      sb.AppendLine("}"); // if

      // begin to use op helper.
      AppendOpHelperCall(op, sb);
      sb.AppendLine("var _result = _op.outputs;");

      // check if it needs to record gradient.
      sb.Append("if(_execute.must_record_gradient())\n{\n");
      sb.Append("object[] _attrs = new object[]{");
      foreach (var attr in op.Attr)
      {
          string getter = attr.Type switch
          {
              "type" => "_get_attr_type",
              "int" => "_get_attr_int",
              "bool" => "_get_attr_bool",
              _ => "get_attr",
          };
          // Both quoted names are the op's attr name used as a string key, not a C#
          // identifier, so they must not receive the keyword-escaping "_" suffix.
          sb.Append($"\"{attr.Name}\", _op.{getter}(\"{attr.Name}\"), ");
      }
      // Drop the trailing ", " left by the last attr, if any attr was written.
      if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
      {
          sb.Remove(sb.Length - 2, 2);
      }
      sb.Append("};\n");
      sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);");
      sb.AppendLine("}"); // if

      if (outputArgsCount == 0)
      {
          sb.AppendLine("return _op;");
      }
      else if (outputArgsCount == 1)
      {
          sb.AppendLine("return _result[0];");
      }
      else
      {
          sb.AppendLine("return _result;");
      }
      sb.AppendLine("}"); // body
      sb.AppendLine();

      AppendEagerFallbackDefinition(op, sb);
  }
  /// <summary>
  /// Appends the parameter list of the generated public wrapper for <paramref name="op"/>:
  /// one parameter per input arg, then one per attr returned by
  /// <see cref="GetAttrsDefaultValue"/>, then <c>string? name = null</c>.
  /// </summary>
  /// <remarks>
  /// Names that collide with a C# keyword get a trailing underscore. Attr parameters without
  /// a default ("NOVALUE") are written before those with a default, because C# requires
  /// optional parameters to come after all required ones.
  /// </remarks>
  /// <param name="op">Op definition whose inputs and attrs become parameters.</param>
  /// <param name="sb">Builder that receives the generated parameter list.</param>
  public void AppendArgs(OpDef op, StringBuilder sb)
  {
      foreach (var arg in op.InputArg)
      {
          string argName = arg.Name;
          if (SyntaxFactory.ParseToken(argName).IsKeyword())
          {
              argName = $"{argName}_";
          }
          // Inputs with a number attr are emitted as "Tensors", all others as "Tensor".
          string argType = string.IsNullOrEmpty(arg.NumberAttr) ? "Tensor" : "Tensors";
          sb.Append($"{argType} {argName}, ");
      }

      // Required attr parameters go straight to sb; optional ones are buffered and
      // appended afterwards so the generated signature stays valid C#.
      var optionalArgs = new StringBuilder();
      foreach (var (key, (typeStr, value)) in GetAttrsDefaultValue(op))
      {
          string realKey = key;
          if (SyntaxFactory.ParseToken(key).IsKeyword())
          {
              realKey += "_";
          }

          if (value == "NOVALUE")
          {
              sb.Append($"{typeStr} {realKey}, ");
          }
          else
          {
              optionalArgs.Append($"{typeStr} {realKey} = {value}, ");
          }
      }
      sb.Append(optionalArgs.ToString());

      sb.Append("string? name = null");
  }
  /// <summary>
  /// Appends the generated statement that assigns <c>_fast_path_result</c> from
  /// <c>tf.Runner.TFE_FastPathExecute</c>, passing the op's inputs followed by
  /// name/value pairs for each attr returned by <see cref="GetAttrsDefaultValue"/>.
  /// </summary>
  /// <param name="op">Op definition supplying the op name, inputs and attrs.</param>
  /// <param name="sb">Builder that receives the generated statement.</param>
  public void AppendFastPathExecute(OpDef op, StringBuilder sb)
  {
      sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name, ");

      foreach (var arg in op.InputArg)
      {
          string attrArgName = arg.Name;
          if (SyntaxFactory.ParseToken(attrArgName).IsKeyword())
          {
              attrArgName += "_";
          }
          sb.Append($"{attrArgName}, ");
      }

      foreach (var (key, _) in GetAttrsDefaultValue(op))
      {
          // The quoted key is the attr name; the value must use the same keyword-escaped
          // identifier under which the parameter is declared in the generated signature.
          string keyRealName = key;
          if (SyntaxFactory.ParseToken(keyRealName).IsKeyword())
          {
              keyRealName += "_";
          }
          sb.Append($"\"{key}\", {keyRealName}, ");
      }

      // Drop the trailing ", " before closing the call.
      if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
      {
          sb.Remove(sb.Length - 2, 2);
      }
      sb.Append("));\n");
  }
  /// <summary>
  /// Appends the generated <c>return &lt;op&gt;_eager_fallback(...)</c> statement: inputs are
  /// passed positionally, attrs as named arguments, followed by <c>name</c> and <c>ctx</c>.
  /// </summary>
  /// <param name="op">Op definition supplying the function name, inputs and attrs.</param>
  /// <param name="sb">Builder that receives the generated statement.</param>
  public void AppendEagerFallbackCall(OpDef op, StringBuilder sb)
  {
      string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback";
      sb.Append($"return {funcName}(");

      foreach (var arg in op.InputArg)
      {
          string inputArgRealName = arg.Name;
          if (SyntaxFactory.ParseToken(inputArgRealName).IsKeyword())
          {
              inputArgRealName += "_";
          }
          sb.Append($"{inputArgRealName}, ");
      }

      foreach (var (key, _) in GetAttrsDefaultValue(op))
      {
          string keyRealName = key;
          if (SyntaxFactory.ParseToken(keyRealName).IsKeyword())
          {
              keyRealName += "_";
          }
          // The named-argument label must match the parameter name declared by
          // AppendFallBackFunctionArgs, which is the keyword-escaped name.
          sb.Append($"{keyRealName}: {keyRealName}, ");
      }

      sb.Append("name: name, ctx: _ctx);\n");
  }
  /// <summary>
  /// Appends the generated source of the <c>&lt;op&gt;_eager_fallback</c> function for
  /// <paramref name="op"/> to <paramref name="sb"/>.
  /// </summary>
  /// <remarks>
  /// The emitted function returns <c>Tensor[]</c> for more than one output arg and
  /// <c>Tensor</c> otherwise. If any input is a ref, the body is a single throw. Otherwise
  /// it builds <c>_inputs_flat</c> and <c>_attrs</c>, calls <c>_execute.execute</c>,
  /// optionally records the gradient and returns the result.
  /// </remarks>
  /// <param name="op">Op definition the fallback is generated from.</param>
  /// <param name="sb">Builder that receives the generated source text.</param>
  public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb)
  {
      // Appends "_" to names that collide with a C# keyword, as done for declared parameters.
      static string EscapeKeyword(string identifier) =>
          SyntaxFactory.ParseToken(identifier).IsKeyword() ? $"{identifier}_" : identifier;

      sb.Append("public static Tensor");
      int outputArgsCount = op.OutputArg.Count;
      if (outputArgsCount > 1)
      {
          sb.Append("[]");
      }
      string funcName = Utils.ConvertToUnderscore(op.Name);
      sb.Append($" {funcName}_eager_fallback(");
      AppendFallBackFunctionArgs(op, sb);
      sb.Append(")\n{\n");

      var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null);
      if (possibleRefArg is not null)
      {
          sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." +
              $" Arg '{possibleRefArg.Name}' is a ref.\");");
          sb.AppendLine("}"); // body
          return;
      }

      sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
      foreach (var arg in op.InputArg)
      {
          sb.Append($"{EscapeKeyword(arg.Name)}, ");
      }
      if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
      {
          sb.Remove(sb.Length - 2, 2);
      }
      sb.Append("};\n");

      sb.Append("object[] _attrs = new object[]{");
      foreach (var attr in op.Attr)
      {
          bool isNumberAttr = attr.Type == "int"
              && (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name));

          if (attr.Type == "type")
          {
              var typedArg = op.InputArg.FirstOrDefault(x => x.TypeAttr == attr.Name, null);
              if (typedArg is not null)
              {
                  // The dtype is taken from the first input declared with this type attr.
                  sb.Append($"\"{attr.Name}\", {EscapeKeyword(typedArg.Name)}.dtype, ");
              }
              else
              {
                  // A type attr not bound to an input is declared as a TF_DataType parameter
                  // under its own (escaped) name by GetAttrsDefaultValue, so pass that parameter.
                  sb.Append($"\"{attr.Name}\", {EscapeKeyword(attr.Name)}, ");
              }
          }
          else if (isNumberAttr)
          {
              var countedArg = op.InputArg.FirstOrDefault(x => x.NumberAttr == attr.Name, null);
              if (countedArg is not null)
              {
                  sb.Append($"\"{attr.Name}\", {EscapeKeyword(countedArg.Name)}.Length, ");
              }
              // NOTE(review): when only an output arg references this number attr, nothing is
              // emitted for it and no parameter is declared for it either - confirm intended.
          }
          else
          {
              sb.Append($"\"{attr.Name}\", {EscapeKeyword(attr.Name)}, ");
          }
      }
      if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
      {
          sb.Remove(sb.Length - 2, 2);
      }
      sb.Append("};\n");

      sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, " +
          $"attrs: _attrs, ctx: ctx, name: name);");
      sb.Append("if(_execute.must_record_gradient())\n{\n");
      sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);");
      sb.AppendLine("}"); // if

      if (outputArgsCount == 0)
      {
          sb.AppendLine("return null;");
      }
      else if (outputArgsCount == 1)
      {
          sb.AppendLine("return _result[0];");
      }
      else
      {
          sb.AppendLine("return _result;");
      }
      sb.AppendLine("}"); // body
  }
  /// <summary>
  /// Appends the parameter list of the generated eager-fallback function: one parameter per
  /// input arg, one per attr from <see cref="GetAttrsDefaultValue"/> (without default
  /// values), then <c>string name, Context ctx</c>.
  /// </summary>
  /// <param name="op">Op definition whose inputs and attrs become parameters.</param>
  /// <param name="sb">Builder that receives the generated parameter list.</param>
  public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb)
  {
      foreach (var input in op.InputArg)
      {
          // Keyword-colliding names get a trailing underscore.
          bool clashesWithKeyword = SyntaxFactory.ParseToken(input.Name).IsKeyword();
          string parameterName = clashesWithKeyword ? $"{input.Name}_" : input.Name;

          // Inputs with a number attr are typed "Tensors", the rest "Tensor".
          string parameterType = string.IsNullOrEmpty(input.NumberAttr) ? "Tensor" : "Tensors";

          sb.Append(parameterType).Append(' ').Append(parameterName).Append(", ");
      }

      foreach (var (attrName, (attrType, _)) in GetAttrsDefaultValue(op))
      {
          string parameterName = SyntaxFactory.ParseToken(attrName).IsKeyword()
              ? attrName + "_"
              : attrName;
          sb.Append(attrType).Append(' ').Append(parameterName).Append(", ");
      }

      sb.Append("string name, Context ctx");
  }
  /// <summary>
  /// Appends the generated statements that fill a <c>keywords</c> dictionary with the op's
  /// inputs and attrs and then assign <c>_op</c> from <c>tf.OpDefLib._apply_op_helper</c>.
  /// </summary>
  /// <param name="op">Op definition supplying the op name, inputs and attrs.</param>
  /// <param name="sb">Builder that receives the generated statements.</param>
  public void AppendOpHelperCall(OpDef op, StringBuilder sb)
  {
      sb.AppendLine("Dictionary<string, object> keywords = new();");

      foreach (var arg in op.InputArg)
      {
          string realArgName = arg.Name;
          if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
          {
              realArgName += "_";
          }
          sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};");
      }

      foreach (var (key, _) in GetAttrsDefaultValue(op))
      {
          // The dictionary key is the attr name; the value is the generated parameter,
          // which is declared under the keyword-escaped name.
          string realKey = key;
          if (SyntaxFactory.ParseToken(realKey).IsKeyword())
          {
              realKey += "_";
          }
          sb.AppendLine($"keywords[\"{key}\"] = {realKey};");
      }

      sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
  }
  /// <summary>
  /// Maps each attr of <paramref name="op"/> that becomes a generated parameter to its
  /// C# type name and default-value expression.
  /// </summary>
  /// <remarks>
  /// Entries are <c>attr name -&gt; (type string, default value)</c>. The default is the
  /// sentinel <c>"NOVALUE"</c> when the parameter has no default. Two kinds of attr are
  /// skipped: "type" attrs referenced as the <c>TypeAttr</c> of an input, and "int" attrs
  /// referenced as the <c>NumberAttr</c> of an input or output. An "int" attr without a
  /// declared default gets <c>"0"</c>.
  /// </remarks>
  /// <param name="op">Op definition whose attrs are inspected.</param>
  /// <returns>Dictionary keyed by the unescaped attr name.</returns>
  /// <exception cref="NotImplementedException">An attr has a type this generator does not handle.</exception>
  public Dictionary<string, (string, string)> GetAttrsDefaultValue(OpDef op)
  {
      const string NoValue = "NOVALUE";
      Dictionary<string, (string, string)> dic = new();

      foreach (var attr in op.Attr)
      {
          var defaultValue = attr.DefaultValue;
          switch (attr.Type)
          {
              case "type":
              {
                  if (op.InputArg.Any(x => x.TypeAttr == attr.Name))
                  {
                      break;
                  }
                  string enumName = null;
                  if (defaultValue is not null && defaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
                  {
                      enumName = Enum.GetName(typeof(TF_DataType), defaultValue.Type.as_tf_dtype());
                  }
                  // Enum.GetName returns null for values without a named member; treat that
                  // as "no default" rather than emitting a dangling "TF_DataType.".
                  dic[attr.Name] = ("TF_DataType", enumName is null ? NoValue : typeof(TF_DataType).Name + "." + enumName);
                  break;
              }
              case "int":
              {
                  if (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name))
                  {
                      break;
                  }
                  bool hasDefault = defaultValue is not null && defaultValue.ValueCase == AttrValue.ValueOneofCase.I;
                  dic[attr.Name] = ("int", hasDefault
                      ? defaultValue.I.ToString(System.Globalization.CultureInfo.InvariantCulture)
                      : "0");
                  break;
              }
              case "float":
              {
                  bool hasDefault = defaultValue is not null && defaultValue.ValueCase == AttrValue.ValueOneofCase.F;
                  dic[attr.Name] = ("float", hasDefault ? FormatFloatLiteral(defaultValue.F) : NoValue);
                  break;
              }
              case "string":
              {
                  bool hasDefault = defaultValue is not null && defaultValue.ValueCase == AttrValue.ValueOneofCase.S;
                  dic[attr.Name] = ("string", hasDefault ? FormatStringLiteral(defaultValue.S.ToStringUtf8()) : NoValue);
                  break;
              }
              case "bool":
              {
                  bool hasDefault = defaultValue is not null && defaultValue.ValueCase == AttrValue.ValueOneofCase.B;
                  dic[attr.Name] = ("bool", hasDefault ? (defaultValue.B ? "true" : "false") : NoValue);
                  break;
              }
              case "shape":
              {
                  // A declared shape default is emitted as "null"; the concrete shape is not reproduced.
                  bool hasDefault = defaultValue is not null && defaultValue.ValueCase == AttrValue.ValueOneofCase.Shape;
                  dic[attr.Name] = ("Shape", hasDefault ? "null" : NoValue);
                  break;
              }
              case "list(type)":
                  dic[attr.Name] = ("TF_DataType[]", NoValue);
                  break;
              case "list(shape)":
                  dic[attr.Name] = ("Shape[]", NoValue);
                  break;
              case "list(string)":
                  dic[attr.Name] = ("string[]", NoValue);
                  break;
              case "list(int)":
                  dic[attr.Name] = ("int[]", NoValue);
                  break;
              case "list(float)":
                  dic[attr.Name] = ("float[]", NoValue);
                  break;
              case "func":
                  dic[attr.Name] = ("Func<Tensors, Tensors>", NoValue);
                  break;
              case "list(func)":
                  dic[attr.Name] = ("Func<Tensors, Tensors>[]", NoValue);
                  break;
              case "tensor":
                  dic[attr.Name] = ("TensorProto", NoValue);
                  break;
              default:
                  throw new NotImplementedException(
                      $"Attr type '{attr.Type}' of attr '{attr.Name}' in op '{op.Name}' is not supported.");
          }
      }

      return dic;
  }

  // Formats a float as a C# literal, independent of the current culture's decimal separator.
  private static string FormatFloatLiteral(float value)
  {
      if (float.IsNaN(value))
      {
          return "float.NaN";
      }
      if (float.IsPositiveInfinity(value))
      {
          return "float.PositiveInfinity";
      }
      if (float.IsNegativeInfinity(value))
      {
          return "float.NegativeInfinity";
      }
      return value.ToString(System.Globalization.CultureInfo.InvariantCulture) + "f";
  }

  // Wraps text in double quotes, escaping characters that would break a regular C# string literal.
  private static string FormatStringLiteral(string text)
  {
      var literal = new StringBuilder(text.Length + 2);
      literal.Append('"');
      foreach (char c in text)
      {
          switch (c)
          {
              case '\\': literal.Append("\\\\"); break;
              case '"': literal.Append("\\\""); break;
              case '\n': literal.Append("\\n"); break;
              case '\r': literal.Append("\\r"); break;
              case '\t': literal.Append("\\t"); break;
              case '\0': literal.Append("\\0"); break;
              default: literal.Append(c); break;
          }
      }
      literal.Append('"');
      return literal.ToString();
  }
  /// <summary>
  /// Tells whether at least one input arg of <paramref name="op"/> has <c>IsRef</c> set.
  /// </summary>
  /// <param name="op">Op definition whose input args are scanned.</param>
  /// <returns><c>true</c> if any input arg is a ref; otherwise <c>false</c>.</returns>
  private static bool HasRefArgs(OpDef op)
  {
      foreach (var inputArg in op.InputArg)
      {
          if (inputArg.IsRef)
          {
              return true;
          }
      }
      return false;
  }
  516. }
  517. }