You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

GenOpsWriter.cs 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Protobuf.Text;
  7. namespace Tensorflow.CodeGen
  8. {
  9. public class GenOpsWriter
  10. {
  11. private string _basePath;
  12. private Dictionary<string, OpDef> _opMap;
  13. private OpClassifier _opClassifier;
  14. private FunctionGenerator _fg = new();
  15. private DescriptionGenerator _dg;
  16. public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename)
  17. {
  18. _basePath = basePath;
  19. var opDefs = Utils.ReadAllOpDefs(opDefFilename);
  20. _opMap = opDefs.Op.ToDictionary(
  21. x => Utils.ConvertToUnderscore(x.Name), x => x);
  22. _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name)));
  23. _dg = new DescriptionGenerator(apiDefFilesDirectory);
  24. }
  25. public void WriteAll()
  26. {
  27. foreach(var (target, set) in _opClassifier.OpSet)
  28. {
  29. StringBuilder sb = new StringBuilder();
  30. // Write file header.
  31. sb.AppendLine("/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/");
  32. sb.AppendLine();
  33. // Add commonly used namespaces.
  34. sb.AppendLine("using Tensorflow.Eager;");
  35. sb.AppendLine("using Tensorflow.Contexts;");
  36. sb.AppendLine("using static Tensorflow.Binding;");
  37. sb.AppendLine();
  38. // Specify the namespace
  39. sb.AppendLine("namespace Tensorflow;");
  40. sb.AppendLine();
  41. // Write class name
  42. sb.AppendLine($"public static class {target}");
  43. sb.AppendLine("{");
  44. foreach(var funcName in set)
  45. {
  46. if(_opMap.ContainsKey(funcName))
  47. {
  48. var opDef = _opMap[funcName];
  49. // write the descriptions.
  50. _dg.AppendDescription(opDef, sb);
  51. // write the function body.
  52. _fg.AppendFunction(opDef, sb);
  53. }
  54. else if (funcName.StartsWith("_"))
  55. {
  56. var opDef = _opMap[funcName.Substring(1)];
  57. _fg.AppendFunction(opDef, sb);
  58. }
  59. }
  60. // Close class scope.
  61. sb.AppendLine("}");
  62. string fullFilePath = Path.Combine(_basePath, $"{target}.cs");
  63. File.WriteAllText(fullFilePath, sb.ToString());
  64. }
  65. }
  66. }
  67. }