You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GenOpsWriter.cs 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Protobuf.Text;
  namespace Tensorflow.CodeGen
  {
      /// <summary>
      /// Writes generated C# wrapper source files for TensorFlow ops.
      /// For every (target, function-name set) pair exposed by the <see cref="OpClassifier"/>,
      /// one file named <c>{target}.cs</c> is written under the base path. It contains an
      /// <c>internal static class {target}</c> whose members are emitted by a
      /// <see cref="FunctionGenerator"/>, one per function name that resolves to an op definition.
      /// </summary>
      public class GenOpsWriter
      {
          // Output directory for the generated .cs files.
          private readonly string _basePath;

          // Op definitions keyed by Utils.ConvertToUnderscore(opDef.Name).
          private readonly Dictionary<string, OpDef> _opMap;

          // Supplies the target -> function-name sets that decide which ops go in which file.
          private readonly OpClassifier _opClassifier;

          private readonly FunctionGenerator _g = new();

          /// <summary>
          /// Loads all op definitions and builds the op classifier.
          /// </summary>
          /// <param name="basePath">Directory the generated <c>.cs</c> files are written to.</param>
          /// <param name="pythonFilesDirectory">Directory handed to <see cref="OpClassifier"/>.</param>
          /// <param name="opDefFilename">
          /// Path of the file parsed as an <c>OpList</c> via <c>OpList.Parser.ParseText</c>.
          /// </param>
          /// <exception cref="ArgumentNullException"><paramref name="basePath"/> is <c>null</c>.</exception>
          public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename)
          {
              // Fail fast here instead of deep inside WriteAll (Path.Combine would throw later).
              _basePath = basePath ?? throw new ArgumentNullException(nameof(basePath));

              var opDefs = ReadAllOpDefs(opDefFilename);
              _opMap = opDefs.Op.ToDictionary(
                  x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x);
              _opClassifier = new OpClassifier(pythonFilesDirectory);
          }

          /// <summary>
          /// Generates and writes one <c>{target}.cs</c> file per target in
          /// <c>OpClassifier.OpSet</c>, overwriting any existing file of the same name.
          /// Function names that do not resolve to an op definition are skipped.
          /// </summary>
          public void WriteAll()
          {
              // An empty base path means "current directory" for Path.Combine; only create
              // the directory when one was actually given (CreateDirectory("") throws).
              if (_basePath.Length > 0)
              {
                  Directory.CreateDirectory(_basePath);
              }

              foreach (var (target, set) in _opClassifier.OpSet)
              {
                  StringBuilder sb = new StringBuilder();

                  // Write file header.
                  sb.AppendLine("/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/");
                  sb.AppendLine();

                  // Add commonly used namespaces.
                  sb.AppendLine("using Tensorflow.Eager;");
                  sb.AppendLine("using Tensorflow.Contexts;");
                  sb.AppendLine("using static Tensorflow.Binding;");
                  sb.AppendLine();

                  // Specify the namespace.
                  sb.AppendLine("namespace Tensorflow;");
                  sb.AppendLine();

                  // Write class name and open class scope.
                  sb.AppendLine($"internal static class {target}");
                  sb.AppendLine("{");

                  foreach (var funcName in set)
                  {
                      // Names with no matching op definition are skipped uniformly
                      // (previously an unknown "_name" crashed with KeyNotFoundException).
                      if (TryResolveOpDef(funcName, out var opDef))
                      {
                          _g.AppendFunction(opDef, sb);
                      }
                  }

                  // Close class scope.
                  sb.AppendLine("}");

                  string fullFilePath = Path.Combine(_basePath, $"{target}.cs");
                  File.WriteAllText(fullFilePath, sb.ToString());
              }
          }

          /// <summary>
          /// Finds the op definition for a function name collected by the classifier.
          /// The exact name is tried first. If it is absent and the name starts with an
          /// underscore, the name with that single leading underscore removed is tried.
          /// </summary>
          /// <param name="funcName">Function name from a target's set.</param>
          /// <param name="opDef">The matching op definition, or <c>default</c> when none is found.</param>
          /// <returns><c>true</c> if an op definition was found; otherwise <c>false</c>.</returns>
          private bool TryResolveOpDef(string funcName, out OpDef opDef)
          {
              if (_opMap.TryGetValue(funcName, out opDef))
              {
                  return true;
              }

              // Ordinal: this is an identifier check, not linguistic text (CA1310).
              if (funcName.StartsWith("_", StringComparison.Ordinal))
              {
                  return _opMap.TryGetValue(funcName.Substring(1), out opDef);
              }

              return false;
          }

          /// <summary>
          /// Reads <paramref name="path"/> and parses its contents with
          /// <c>OpList.Parser.ParseText</c> (presumably protobuf text format, e.g. an ops .pbtxt — confirm).
          /// </summary>
          /// <param name="path">Path of the op-definition file.</param>
          /// <returns>The parsed <c>OpList</c>.</returns>
          private static OpList ReadAllOpDefs(string path)
          {
              var text = File.ReadAllText(path);
              return OpList.Parser.ParseText(text);
          }
      }
  }