using Protobuf.Text;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.CodeGen
{
    /// <summary>
    /// Generates C# wrapper source files for TensorFlow ops: one file (and one
    /// static class) per entry of the op classifier's <c>OpSet</c>.
    /// </summary>
    public class GenOpsWriter
    {
        /// <summary>Directory that the generated <c>.cs</c> files are written into.</summary>
        private readonly string _basePath;

        /// <summary>
        /// Op definitions keyed by <c>Utils.ConvertToUnderscore(op.Name)</c>.
        /// </summary>
        private readonly Dictionary<string, OpDef> _opMap;

        /// <summary>Supplies the (target class name, function names) groups to emit.</summary>
        private readonly OpClassifier _opClassifier;

        /// <summary>Emits the source text of a single op wrapper function.</summary>
        private readonly FunctionGenerator _g = new();

        /// <summary>
        /// Loads the op definitions and builds the op classifier.
        /// </summary>
        /// <param name="basePath">Directory the generated files are written into.</param>
        /// <param name="pythonFilesDirectory">Directory passed to the <see cref="OpClassifier"/> constructor.</param>
        /// <param name="opDefFilename">Path of the text-format op list file to parse.</param>
        public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename)
        {
            _basePath = basePath;

            var opDefs = ReadAllOpDefs(opDefFilename);
            _opMap = opDefs.Op.ToDictionary(
                x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name),
                x => x);
            _opClassifier = new OpClassifier(pythonFilesDirectory);
        }

        /// <summary>
        /// Writes one <c>{target}.cs</c> file under the base path for every entry of
        /// the classifier's <c>OpSet</c>. Function names that cannot be resolved to a
        /// known op definition are skipped.
        /// </summary>
        public void WriteAll()
        {
            foreach (var (target, set) in _opClassifier.OpSet)
            {
                StringBuilder sb = new StringBuilder();

                // Write file header.
                sb.AppendLine("/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/");
                sb.AppendLine();

                // Add commonly used namespaces.
                sb.AppendLine("using Tensorflow.Eager;");
                sb.AppendLine("using Tensorflow.Contexts;");
                sb.AppendLine("using static Tensorflow.Binding;");
                sb.AppendLine();

                // Specify the namespace
                sb.AppendLine("namespace Tensorflow;");
                sb.AppendLine();

                // Write class name
                sb.AppendLine($"internal static class {target}");
                sb.AppendLine("{");

                foreach (var funcName in set)
                {
                    // Look the name up as-is first; if that fails and the name has a
                    // leading underscore, retry without it. A name that matches neither
                    // way is skipped rather than aborting the whole generation run with
                    // a KeyNotFoundException.
                    if (_opMap.TryGetValue(funcName, out var opDef)
                        || (funcName.StartsWith("_", StringComparison.Ordinal)
                            && _opMap.TryGetValue(funcName.Substring(1), out opDef)))
                    {
                        _g.AppendFunction(opDef, sb);
                    }
                }

                // Close class scope.
                sb.AppendLine("}");

                string fullFilePath = Path.Combine(_basePath, $"{target}.cs");
                File.WriteAllText(fullFilePath, sb.ToString());
            }
        }

        /// <summary>
        /// Reads the file at <paramref name="path"/> and parses its text into an <see cref="OpList"/>.
        /// </summary>
        /// <param name="path">Path of the op list text file.</param>
        /// <returns>The parsed op list.</returns>
        private static OpList ReadAllOpDefs(string path)
        {
            var text = File.ReadAllText(path);
            return OpList.Parser.ParseText(text);
        }
    }
}