From 28642568a22a242dbbddd375472ee1aeb90e7dce Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 8 May 2023 01:57:18 +0800 Subject: [PATCH] feat: description generator of op code. --- Tensorflow.CodeGen/DescriptionGenerator.cs | 263 +++++++++++++++++++ Tensorflow.CodeGen/FunctionGenerator.cs | 201 +------------- Tensorflow.CodeGen/GenOpsWriter.cs | 26 +- Tensorflow.CodeGen/Program.cs | 3 +- Tensorflow.CodeGen/Tensorflow.CodeGen.csproj | 2 +- Tensorflow.CodeGen/Utils.cs | 199 +++++++++++++- 6 files changed, 482 insertions(+), 212 deletions(-) create mode 100644 Tensorflow.CodeGen/DescriptionGenerator.cs diff --git a/Tensorflow.CodeGen/DescriptionGenerator.cs b/Tensorflow.CodeGen/DescriptionGenerator.cs new file mode 100644 index 00000000..0437370a --- /dev/null +++ b/Tensorflow.CodeGen/DescriptionGenerator.cs @@ -0,0 +1,263 @@ +using Microsoft.CodeAnalysis.CSharp; +using Protobuf.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection.Metadata.Ecma335; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; + +namespace Tensorflow.CodeGen +{ + public class DescriptionGenerator + { + private static readonly string replaceStrInner = "~~%~~"; + private static readonly string replaceStrInnerQuotationMarks = "^%^"; + Dictionary> _opDescriptions = new Dictionary>(); + Dictionary _opDescriptionDefs = new Dictionary(); + public DescriptionGenerator(string apiDefDirectory) + { + DirectoryInfo directory = new DirectoryInfo(apiDefDirectory); + + int errors = 0; + foreach (FileInfo file in directory.GetFiles()) + { + string target = file.Name.Split('.')[0].Split('_').Last(); + OpDef op = null; + try + { + op = ReadOpDefs(file.FullName).Op[0]; + } + catch + { + errors++; + continue; + } + _opDescriptionDefs[target] = op; + _opDescriptions[target] = new Dictionary(); + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + _opDescriptions[target][argName] = arg.Description ?? ""; + } + foreach (var arg in op.Attr) + { + var token = SyntaxFactory.ParseToken(arg.Name); + string realKey = arg.Name; + if (token.IsKeyword()) + { + realKey += "_"; + } + _opDescriptions[target][realKey] = arg.Description ?? ""; + } + _opDescriptions[target]["SUMMARY"] = op.Summary ?? ""; + _opDescriptions[target]["DESC"] = op.Description ?? ""; + } + Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " + + $"the failed files number is large, or ignore it."); + } + + /// + /// + /// + /// + /// + public void AppendDescription(OpDef fullOp, StringBuilder sb) + { + var opName = fullOp.Name; + if(_opDescriptions.TryGetValue(opName, out var op)) + { + var def = _opDescriptionDefs[opName]; + sb.AppendLine("/// "); + sb.AppendLine($"/// {op["SUMMARY"]}"); + sb.AppendLine("/// "); + + string totalDesc = op["DESC"]; + if (!string.IsNullOrEmpty(totalDesc)) + { + totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\""); + sb.AppendLine("/// "); + string[] lines = totalDesc.Split(replaceStrInner); + foreach (var line in lines) + { + sb.AppendLine($"/// {line}"); + } + sb.AppendLine("/// "); + } + + var argNames = GetInputArgNames(fullOp); + foreach (var argName in argNames) + { + if(op.TryGetValue(argName, out var desc)) + { + desc = desc.Replace(replaceStrInnerQuotationMarks, "\""); + string[] lines = desc.Split(replaceStrInner); + sb.AppendLine($"/// "); + foreach (var line in lines) + { + sb.AppendLine($"/// {line}"); + } + sb.AppendLine("/// "); + } + else + { + sb.AppendLine($"/// "); + } + } + + List returnValueDescs = new(); + foreach (var arg in def.OutputArg) + { + if (!string.IsNullOrEmpty(arg.Description)) + { + returnValueDescs.Add($"{arg.Name}: {arg.Description}"); + } + } + string returnValueDesc = ""; + if (returnValueDescs.Count > 0) + { + returnValueDesc = string.Join(" && ", returnValueDescs); + } + sb.AppendLine($"/// {returnValueDesc}"); + } + else + { + sb.AppendLine("/// "); + sb.AppendLine($"///"); + sb.AppendLine("/// "); + + var argNames = GetInputArgNames(fullOp); + foreach (var argName in argNames) + { + sb.AppendLine($"/// "); + } + + sb.AppendLine($"/// "); + } + } + + /// + /// + /// + /// + /// + /// + /// + public List GetInputArgNames(OpDef op) + { + List names = new(); + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + names.Add(argName); + } + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); + foreach (var (key, typeStr, value) in attrValueDic) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) + { + realKey += "_"; + } + names.Add(realKey); + } + return names; + } + + private static OpList ReadOpDefs(string path) + { + var text = File.ReadAllText(path); + text = RemoveLintTags(text); + text = PreProcessText(text); + + string pattern = @"< { + string matchedText = match.Value; + string innerText = match.Groups[1].Value; + innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks) + .Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符 + return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾 + }, RegexOptions.Multiline); + + var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse(replacedText); + return opDefs; + } + + static string PreProcessText(string input) + { + int depth = 0; + int endBlockDepth = -1; + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < input.Length; i++) + { + char c = input[i]; + if (c == '{') + { + depth++; + sb.Append(c); + } + else if (c == '}') + { + if (depth == endBlockDepth) + { + sb.Append("END\n"); + endBlockDepth = -1; + } + sb.Append(c); + depth--; + } + else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "< x.Item3 == "NOVALUE")) { var token = SyntaxFactory.ParseToken(key); @@ -226,7 +226,7 @@ namespace Tensorflow.CodeGen } sb.Append("}, attrs = new Dictionary(){ "); - var attrValueDic = GetAttrsDefaultValue(op, out var _); + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); foreach (var (key, _, _) in attrValueDic) { sb.Append($"[\"{key}\"] = {key}, "); @@ -252,7 +252,7 @@ namespace Tensorflow.CodeGen } sb.Append($"{inputArgRealName}, "); } - var attrValueDic = GetAttrsDefaultValue(op, out var _); + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); foreach (var (key, _, _) in attrValueDic) { string keyRealName = key; @@ -439,7 +439,7 @@ namespace Tensorflow.CodeGen sb.Append($"Tensor {argName}, "); } } - var attrValueDic = GetAttrsDefaultValue(op, out var _); + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); foreach (var (key, typeStr, _) in attrValueDic) { var token = SyntaxFactory.ParseToken(key); @@ -465,7 +465,7 @@ namespace Tensorflow.CodeGen } sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); } - var attrValueDic = GetAttrsDefaultValue(op, out var _); + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); foreach (var (key, _, _) in attrValueDic) { sb.AppendLine($"keywords[\"{key}\"] = {key};"); @@ -473,195 +473,6 @@ namespace Tensorflow.CodeGen sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); } - // name, type string, default value - public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary dynamicDefaultValues) - { - dynamicDefaultValues = new(); - List<(string, string, string)> res = new(); - foreach (var attr in op.Attr) - { - if (attr.Type == "type") - { - bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); - if (!found) - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) - { - string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); - string enumPath = typeof(TF_DataType).Name + "." + name; - res.Add((attr.Name, "TF_DataType", enumPath)); - } - else - { - res.Add((attr.Name, "TF_DataType", "NOVALUE")); - } - } - } - else if (attr.Type == "int") - { - if(op.InputArg.Any(x => x.NumberAttr == attr.Name)) - { - continue; - } - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) - { - res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); - } - else - { - res.Add((attr.Name, "int", "0")); - } - } - else if (attr.Type == "float") - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) - { - res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); - } - else - { - res.Add((attr.Name, "float", "NOVALUE")); - } - } - else if (attr.Type == "string") - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) - { - res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); - } - else - { - res.Add((attr.Name, "string", "NOVALUE")); - } - } - else if (attr.Type == "bool") - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) - { - res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); - } - else - { - res.Add((attr.Name, "bool", "NOVALUE")); - } - } - else if (attr.Type == "shape") - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) - { - if (attr.DefaultValue.Shape.UnknownRank) - { - res.Add((attr.Name, "Shape", $"null")); - } - else - { - Shape shape = new Shape(attr.DefaultValue.Shape); - string expression = $"new Shape({string.Join(", ", shape.dims)})"; - dynamicDefaultValues[attr.Name] = expression; - res.Add((attr.Name, "Shape", $"null")); - } - } - else - { - res.Add((attr.Name, "Shape", "NOVALUE")); - } - } - else if (attr.Type == "list(type)") - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) - { - List values = new(); - foreach (var value in attr.DefaultValue.List.Type) - { - values.Add(value.as_tf_dtype()); - } - string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; - dynamicDefaultValues[attr.Name] = expression; - res.Add((attr.Name, "TF_DataType[]", $"null")); - } - else - { - res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); - } - } - else if (attr.Type == "list(shape)") - { - res.Add((attr.Name, "Shape[]", "NOVALUE")); - } - else if (attr.Type == "list(string)") - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) - { - List values = new(); - foreach (var value in attr.DefaultValue.List.S) - { - values.Add(value.ToStringUtf8()); - } - string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; - dynamicDefaultValues[attr.Name] = expression; - res.Add((attr.Name, "string[]", $"null")); - } - else - { - res.Add((attr.Name, "string[]", "NOVALUE")); - } - } - else if (attr.Type == "list(int)") - { - if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) - { - List values = new(); - foreach(var value in attr.DefaultValue.List.I) - { - values.Add((int)value); - } - string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}"; - dynamicDefaultValues[attr.Name] = expression; - res.Add((attr.Name, "int[]", $"null")); - } - else - { - res.Add((attr.Name, "int[]", "NOVALUE")); - } - } - else if (attr.Type == "list(float)") - { - if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) - { - List values = new(); - foreach (var value in attr.DefaultValue.List.F) - { - values.Add(value); - } - string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; - dynamicDefaultValues[attr.Name] = expression; - res.Add((attr.Name, "float[]", $"null")); - } - else - { - res.Add((attr.Name, "float[]", "NOVALUE")); - } - } - else if (attr.Type == "func") - { - res.Add((attr.Name, "Func", "NOVALUE")); - } - else if (attr.Type == "list(func)") - { - res.Add((attr.Name, "Func[]", "NOVALUE")); - } - else if (attr.Type == "tensor") - { - res.Add((attr.Name, "TensorProto", "NOVALUE")); - } - else - { - throw new NotImplementedException(); - } - } - return res; - } - private static bool HasRefArgs(OpDef op) { return op.InputArg.Any(x => x.IsRef); diff --git a/Tensorflow.CodeGen/GenOpsWriter.cs b/Tensorflow.CodeGen/GenOpsWriter.cs index 2cd7bca5..7601acdb 100644 --- a/Tensorflow.CodeGen/GenOpsWriter.cs +++ b/Tensorflow.CodeGen/GenOpsWriter.cs @@ -12,16 +12,18 @@ namespace Tensorflow.CodeGen private string _basePath; private Dictionary _opMap; private OpClassifier _opClassifier; - private FunctionGenerator _g = new(); + private FunctionGenerator _fg = new(); + private DescriptionGenerator _dg; - public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename) + public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename) { _basePath = basePath; - var opDefs = ReadAllOpDefs(opDefFilename); + var opDefs = Utils.ReadAllOpDefs(opDefFilename); _opMap = opDefs.Op.ToDictionary( - x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); + x => Utils.ConvertToUnderscore(x.Name), x => x); _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); + _dg = new DescriptionGenerator(apiDefFilesDirectory); } public void WriteAll() @@ -53,12 +55,17 @@ namespace Tensorflow.CodeGen if(_opMap.ContainsKey(funcName)) { var opDef = _opMap[funcName]; - _g.AppendFunction(opDef, sb); + + // write the descriptions. + _dg.AppendDescription(opDef, sb); + + // write the function body. + _fg.AppendFunction(opDef, sb); } else if (funcName.StartsWith("_")) { var opDef = _opMap[funcName.Substring(1)]; - _g.AppendFunction(opDef, sb); + _fg.AppendFunction(opDef, sb); } } @@ -69,12 +76,5 @@ namespace Tensorflow.CodeGen File.WriteAllText(fullFilePath, sb.ToString()); } } - - private OpList ReadAllOpDefs(string path) - { - var text = File.ReadAllText(path); - var opDefs = OpList.Parser.ParseText(text); - return opDefs; - } } } diff --git a/Tensorflow.CodeGen/Program.cs b/Tensorflow.CodeGen/Program.cs index a26031cb..f9d44ce8 100644 --- a/Tensorflow.CodeGen/Program.cs +++ b/Tensorflow.CodeGen/Program.cs @@ -5,10 +5,9 @@ using System.Text; using System.Xml.Linq; using Tensorflow.CodeGen; -//Console.WriteLine(Utils.ConvertToUnderscore("LRN")); - GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", + @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); writer.WriteAll(); diff --git a/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj b/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj index a052eb69..865db126 100644 --- a/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj +++ b/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj @@ -9,11 +9,11 @@ - + diff --git a/Tensorflow.CodeGen/Utils.cs b/Tensorflow.CodeGen/Utils.cs index 608222e0..d3f30d9f 100644 --- a/Tensorflow.CodeGen/Utils.cs +++ b/Tensorflow.CodeGen/Utils.cs @@ -1,4 +1,5 @@ -using System; +using Protobuf.Text; +using System; using System.Collections.Generic; using System.Linq; using System.Reflection.Metadata.Ecma335; @@ -51,5 +52,201 @@ namespace Tensorflow.CodeGen return result.ToString(); } + + public static OpList ReadAllOpDefs(string path) + { + var text = File.ReadAllText(path); + var opDefs = OpList.Parser.ParseText(text); + return opDefs; + } + + // name, type string, default value + public static List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary dynamicDefaultValues) + { + dynamicDefaultValues = new(); + List<(string, string, string)> res = new(); + foreach (var attr in op.Attr) + { + if (attr.Type == "type") + { + bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); + if (!found) + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) + { + string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); + string enumPath = typeof(TF_DataType).Name + "." + name; + res.Add((attr.Name, "TF_DataType", enumPath)); + } + else + { + res.Add((attr.Name, "TF_DataType", "NOVALUE")); + } + } + } + else if (attr.Type == "int") + { + if (op.InputArg.Any(x => x.NumberAttr == attr.Name)) + { + continue; + } + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) + { + res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); + } + else + { + res.Add((attr.Name, "int", "0")); + } + } + else if (attr.Type == "float") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) + { + res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); + } + else + { + res.Add((attr.Name, "float", "NOVALUE")); + } + } + else if (attr.Type == "string") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) + { + res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); + } + else + { + res.Add((attr.Name, "string", "NOVALUE")); + } + } + else if (attr.Type == "bool") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) + { + res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); + } + else + { + res.Add((attr.Name, "bool", "NOVALUE")); + } + } + else if (attr.Type == "shape") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) + { + if (attr.DefaultValue.Shape.UnknownRank) + { + res.Add((attr.Name, "Shape", $"null")); + } + else + { + Shape shape = new Shape(attr.DefaultValue.Shape); + string expression = $"new Shape({string.Join(", ", shape.dims)})"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "Shape", $"null")); + } + } + else + { + res.Add((attr.Name, "Shape", "NOVALUE")); + } + } + else if (attr.Type == "list(type)") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.Type) + { + values.Add(value.as_tf_dtype()); + } + string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "TF_DataType[]", $"null")); + } + else + { + res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); + } + } + else if (attr.Type == "list(shape)") + { + res.Add((attr.Name, "Shape[]", "NOVALUE")); + } + else if (attr.Type == "list(string)") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.S) + { + values.Add(value.ToStringUtf8()); + } + string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "string[]", $"null")); + } + else + { + res.Add((attr.Name, "string[]", "NOVALUE")); + } + } + else if (attr.Type == "list(int)") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.I) + { + values.Add((int)value); + } + string expression = "new int[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "int[]", $"null")); + } + else + { + res.Add((attr.Name, "int[]", "NOVALUE")); + } + } + else if (attr.Type == "list(float)") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.F) + { + values.Add(value); + } + string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "float[]", $"null")); + } + else + { + res.Add((attr.Name, "float[]", "NOVALUE")); + } + } + else if (attr.Type == "func") + { + res.Add((attr.Name, "Func", "NOVALUE")); + } + else if (attr.Type == "list(func)") + { + res.Add((attr.Name, "Func[]", "NOVALUE")); + } + else if (attr.Type == "tensor") + { + res.Add((attr.Name, "TensorProto", "NOVALUE")); + } + else + { + throw new NotImplementedException(); + } + } + return res; + } } }