You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

GenOpsWriter.cs 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. using Protobuf.Text;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Threading.Tasks;
  7. namespace Tensorflow.CodeGen
  8. {
  /// <summary>
  /// Generates C# wrapper source files for TensorFlow ops: for every target
  /// reported by the op classifier it writes one file, <c>{target}.cs</c>,
  /// containing a <c>public static class</c> with one wrapper per op name.
  /// </summary>
  public class GenOpsWriter
  {
      // Directory that receives the generated "{target}.cs" files.
      private readonly string _basePath;
      // Op definitions keyed by their Utils.ConvertToUnderscore-converted name.
      private readonly Dictionary<string, OpDef> _opMap;
      private readonly OpClassifier _opClassifier;
      private readonly FunctionGenerator _fg = new();
      private readonly DescriptionGenerator _dg;

      /// <summary>
      /// Loads the op definitions and prepares the classifier and description generator.
      /// </summary>
      /// <param name="basePath">Directory the generated files are written into.</param>
      /// <param name="pythonFilesDirectory">Directory handed to <c>OpClassifier</c>.</param>
      /// <param name="apiDefFilesDirectory">Directory handed to <c>DescriptionGenerator</c>.</param>
      /// <param name="opDefFilename">File read by <c>Utils.ReadAllOpDefs</c>.</param>
      public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename)
      {
          _basePath = basePath;
          var opDefs = Utils.ReadAllOpDefs(opDefFilename);
          // ToDictionary throws if two ops convert to the same underscore name.
          _opMap = opDefs.Op.ToDictionary(
              x => Utils.ConvertToUnderscore(x.Name), x => x);
          _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name)));
          _dg = new DescriptionGenerator(apiDefFilesDirectory);
      }

      /// <summary>
      /// Writes one generated source file per classifier target into the base path,
      /// overwriting any existing file of the same name.
      /// </summary>
      /// <remarks>I/O exceptions from <c>File.WriteAllText</c> propagate to the caller.</remarks>
      public void WriteAll()
      {
          foreach (var (target, set) in _opClassifier.OpSet)
          {
              var sb = new StringBuilder();

              // Write file header.
              sb.AppendLine("/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/");
              sb.AppendLine();

              // Add commonly used namespaces.
              sb.AppendLine("using Tensorflow.Eager;");
              sb.AppendLine("using Tensorflow.Contexts;");
              sb.AppendLine("using Tensorflow.Exceptions;");
              sb.AppendLine("using static Tensorflow.Binding;");
              sb.AppendLine();

              // Specify the namespace.
              sb.AppendLine("namespace Tensorflow;");
              sb.AppendLine();

              // Open the class named after the target.
              sb.AppendLine($"public static class {target}");
              sb.AppendLine("{");

              foreach (var funcName in set)
              {
                  AppendOperation(sb, funcName);
              }

              // Close class scope.
              sb.AppendLine("}");

              string fullFilePath = Path.Combine(_basePath, $"{target}.cs");
              File.WriteAllText(fullFilePath, sb.ToString());
          }
      }

      // Appends the generated code for a single function name to sb.
      // A direct op match gets its description and function body. Otherwise, a
      // name with a leading underscore is resolved against the op with the
      // underscore stripped, and only the function body is emitted. Names that
      // resolve to no op are skipped. Previously a missing underscore-stripped
      // name threw KeyNotFoundException.
      private void AppendOperation(StringBuilder sb, string funcName)
      {
          if (_opMap.TryGetValue(funcName, out var opDef))
          {
              // Write the descriptions, then the function body.
              _dg.AppendDescription(opDef, sb);
              _fg.AppendFunction(opDef, sb);
          }
          else if (funcName.StartsWith("_", StringComparison.Ordinal)
                   && _opMap.TryGetValue(funcName.Substring(1), out var strippedOpDef))
          {
              _fg.AppendFunction(strippedOpDef, sb);
          }
      }
  }
  68. }