diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 0c7d6e3c..8d548814 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -35,6 +35,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "helpers", "helpers", "{E1A5 EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest.RedistHolder", "helpers\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj", "{62D543A2-8846-45A3-829B-5754B094A8E2}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.CodeGen", "Tensorflow.CodeGen\Tensorflow.CodeGen.csproj", "{BADBB104-2F03-4824-A249-803A871D8122}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "protobuf.Text", "..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj", "{151B3A8A-8576-4190-BD58-F42944A49718}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -282,6 +286,42 @@ Global {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x64.Build.0 = Release|Any CPU {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.ActiveCfg = Release|Any CPU {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.Build.0 = Release|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Debug|x64.ActiveCfg = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Debug|x64.Build.0 = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Debug|x86.ActiveCfg = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Debug|x86.Build.0 = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.GPU|Any CPU.Build.0 = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.GPU|x64.ActiveCfg = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.GPU|x64.Build.0 = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.GPU|x86.ActiveCfg = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.GPU|x86.Build.0 = Debug|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Release|Any CPU.Build.0 = Release|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Release|x64.ActiveCfg = Release|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Release|x64.Build.0 = Release|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Release|x86.ActiveCfg = Release|Any CPU + {BADBB104-2F03-4824-A249-803A871D8122}.Release|x86.Build.0 = Release|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Debug|Any CPU.Build.0 = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x64.ActiveCfg = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x64.Build.0 = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x86.ActiveCfg = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x86.Build.0 = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.GPU|Any CPU.Build.0 = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x64.ActiveCfg = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x64.Build.0 = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x86.ActiveCfg = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x86.Build.0 = Debug|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Release|Any CPU.ActiveCfg = Release|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Release|Any CPU.Build.0 = Release|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Release|x64.ActiveCfg = Release|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Release|x64.Build.0 = Release|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Release|x86.ActiveCfg = Release|Any CPU + {151B3A8A-8576-4190-BD58-F42944A49718}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -300,6 +340,7 @@ Global {9738D16A-CFA0-405C-A7DF-D3D203B0CB18} = {01A1787F-A9BE-4221-84E8-6360DD010AB6} {7DEA8760-E401-4872-81F3-405F185A13A0} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} {62D543A2-8846-45A3-829B-5754B094A8E2} = {E1A5D2B7-10AF-4876-85C0-7714EF274214} + {BADBB104-2F03-4824-A249-803A871D8122} = {E1A5D2B7-10AF-4876-85C0-7714EF274214} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {2DEAD3CC-486B-4918-A607-50B0DE7B114A} diff --git a/Tensorflow.CodeGen/FunctionGenerator.cs b/Tensorflow.CodeGen/FunctionGenerator.cs new file mode 100644 index 00000000..d4520307 --- /dev/null +++ b/Tensorflow.CodeGen/FunctionGenerator.cs @@ -0,0 +1,550 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection.Metadata.Ecma335; +using System.Text; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.CSharp; + +namespace Tensorflow.CodeGen +{ + public class FunctionGenerator + { + public void AppendFunction(OpDef op, StringBuilder sb) + { + // TODO: add descriptions + sb.Append("public static "); + int outputArgsCount = op.OutputArg.Count; + if (outputArgsCount > 1) + { + sb.Append("Tensor[] "); + } + else if (outputArgsCount == 1) + { + sb.Append("Tensor "); + } + else + { + sb.Append("Operation "); + } + string funcName = Utils.ConvertToUnderscore(op.Name); + var token = SyntaxFactory.ParseToken(funcName); + if (token.IsKeyword()) + { + funcName = $"_{funcName}"; + } + sb.Append($" {funcName}("); + + // define args + AppendArgs(op, sb); + sb.Append(")\n{\n"); + + // begin to write main body + sb.AppendLine("var _ctx = tf.Context;"); + sb.AppendLine("if(_ctx.executing_eagerly()){"); + + if(HasRefArgs(op)) + { + var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null); + sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {possibleRefArg.Name} is a ref.\");"); + } + else + { + sb.Append("try\n{\n"); + + AppendFastPathExecute(op, sb); + if (outputArgsCount == 0) + { + sb.AppendLine("return null;"); + } + else if (outputArgsCount == 1) + { + sb.AppendLine("return _fast_path_result[0];"); + } + else + { + sb.AppendLine("return _fast_path_result;"); + } + + sb.AppendLine("}"); // try + + sb.Append("catch(Exception)\n{\n"); + sb.AppendLine("}"); // catch + + sb.Append("try\n{\n"); + AppendEagerFallbackCall(op, sb); + sb.AppendLine("}"); // try + + sb.Append("catch(Exception)\n{\n"); + sb.AppendLine("}"); // catch + } + + sb.AppendLine("}"); // if + + // begin to use op helper. + AppendOpHelperCall(op, sb); + sb.AppendLine("var _result = _op.outputs;"); + + // check if it needs to record gradient. + sb.Append("if(_execute.must_record_gradient())\n{\n"); + sb.Append("object[] _attrs = new object[]{"); + foreach (var attr in op.Attr) + { + string attrRealName = attr.Name; + if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) + { + attrRealName += "_"; + } + if (attr.Type == "type") + { + sb.Append($"\"{attr.Name}\", _op._get_attr_type(\"{attrRealName}\"), "); + } + else if (attr.Type == "int") + { + sb.Append($"\"{attr.Name}\", _op._get_attr_int(\"{attrRealName}\"), "); + } + else if (attr.Type == "bool") + { + sb.Append($"\"{attr.Name}\", _op._get_attr_bool(\"{attrRealName}\"), "); + } + else + { + sb.Append($"\"{attr.Name}\", _op.get_attr(\"{attr.Name}\"), "); + } + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("};\n"); + sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);"); + + sb.AppendLine("}"); // if + + if (outputArgsCount == 0) + { + sb.AppendLine("return _op;"); + } + else if (outputArgsCount == 1) + { + sb.AppendLine("return _result[0];"); + } + else + { + sb.AppendLine("return _result;"); + } + sb.AppendLine("}"); // body + + sb.AppendLine(); + + AppendEagerFallbackDefinition(op, sb); + } + + public void AppendArgs(OpDef op, StringBuilder sb) + { + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + if (!string.IsNullOrEmpty(arg.NumberAttr)) + { + sb.Append($"Tensors {argName}, "); + } + else + { + sb.Append($"Tensor {argName}, "); + } + } + var attrValueDic = GetAttrsDefaultValue(op); + foreach (var (key, (typeStr, value)) in attrValueDic) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) + { + realKey += "_"; + } + if (value != "NOVALUE") + { + sb.Append($"{typeStr} {realKey} = {value}, "); + } + else + { + sb.Append($"{typeStr} {realKey}, "); + } + } + sb.Append($"string? name = null"); + } + + public void AppendFastPathExecute(OpDef op, StringBuilder sb) + { + sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name, "); + foreach (var arg in op.InputArg) + { + string attrArgName = arg.Name; + if (SyntaxFactory.ParseToken(attrArgName).IsKeyword()) + { + attrArgName += "_"; + } + sb.Append($"{attrArgName}, "); + } + var attrValueDic = GetAttrsDefaultValue(op); + foreach (var (key, _) in attrValueDic) + { + sb.Append($"\"{key}\", {key}, "); + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("));\n"); + } + + public void AppendEagerFallbackCall(OpDef op, StringBuilder sb) + { + string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback"; + sb.Append($"return {funcName}("); + foreach (var arg in op.InputArg) + { + string inputArgRealName = arg.Name; + if (SyntaxFactory.ParseToken(inputArgRealName).IsKeyword()) + { + inputArgRealName += "_"; + } + sb.Append($"{inputArgRealName}, "); + } + var attrValueDic = GetAttrsDefaultValue(op); + foreach (var (key, _) in attrValueDic) + { + string keyRealName = key; + if (SyntaxFactory.ParseToken(keyRealName).IsKeyword()) + { + keyRealName += "_"; + } + sb.Append($"{key}: {keyRealName}, "); + } + sb.Append("name: name, ctx: _ctx);\n"); + } + + public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb) + { + sb.Append("public static Tensor"); + int outputArgsCount = op.OutputArg.Count; + if (outputArgsCount > 1) + { + sb.Append("[]"); + } + string opName = op.Name; + string funcName = Utils.ConvertToUnderscore(op.Name); + sb.Append($" {funcName}_eager_fallback("); + AppendFallBackFunctionArgs(op, sb); + sb.Append(")\n{\n"); + + var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null); + if (possibleRefArg is not null) + { + sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." + + $" Arg '{possibleRefArg.Name}' is a ref.\");"); + sb.AppendLine("}"); // body + return; + } + + sb.Append("Tensor[] _inputs_flat = new Tensor[]{"); + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + sb.Append($"{realArgName}, "); + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("};\n"); + + sb.Append("object[] _attrs = new object[]{"); + var attrValueDic = GetAttrsDefaultValue(op); + foreach (var attr in op.Attr) + { + if (attr.Type == "type") + { + bool found = false; + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + if (arg.TypeAttr == attr.Name) + { + sb.Append($"\"{attr.Name}\", {realArgName}.dtype, "); + found = true; + break; + } + } + if (!found) + { + if (attr.Name.StartsWith("T") && attr.Name.Length > 1) + { + string paramName = attr.Name.Substring(1); + if (SyntaxFactory.ParseToken(paramName).IsKeyword()) + { + paramName = $"{paramName}_"; + } + sb.Append($"\"{attr.Name}\", {paramName}.dtype, "); + } + else + { + string attrRealName = attr.Name; + if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) + { + attrRealName = $"{attrRealName}_"; + } + sb.Append($"\"{attr.Name}\", {attrRealName}, "); + } + } + } + else if(attr.Type == "int" && (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name))) + { + bool found = false; + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + if (arg.NumberAttr == attr.Name) + { + sb.Append($"\"{attr.Name}\", {realArgName}.Length, "); + found = true; + break; + } + } + } + else + { + sb.Append($"\"{attr.Name}\", {attr.Name}, "); + } + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("};\n"); + + sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, " + + $"attrs: _attrs, ctx: ctx, name: name);"); + + sb.Append("if(_execute.must_record_gradient())\n{\n"); + + sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);"); + + sb.AppendLine("}"); // if + + if (outputArgsCount == 0) + { + sb.AppendLine("return null;"); + } + else if (outputArgsCount == 1) + { + sb.AppendLine("return _result[0];"); + } + else + { + sb.AppendLine("return _result;"); + } + + sb.AppendLine("}"); // body + } + + public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb) + { + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + if (!string.IsNullOrEmpty(arg.NumberAttr)) + { + sb.Append($"Tensors {argName}, "); + } + else + { + sb.Append($"Tensor {argName}, "); + } + } + var attrValueDic = GetAttrsDefaultValue(op); + foreach (var (key, (typeStr, _)) in attrValueDic) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) + { + realKey += "_"; + } + sb.Append($"{typeStr} {realKey}, "); + } + sb.Append($"string name, Context ctx"); + } + + public void AppendOpHelperCall(OpDef op, StringBuilder sb) + { + sb.AppendLine("Dictionary keywords = new();"); + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName += "_"; + } + sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); + } + var attrValueDic = GetAttrsDefaultValue(op); + foreach (var (key, _) in attrValueDic) + { + sb.Append($"keywords[\"{key}\"] = {key};"); + } + sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); + } + + // key, (type string, default value) + public Dictionary GetAttrsDefaultValue(OpDef op) + { + Dictionary dic = new(); + foreach (var attr in op.Attr) + { + if (attr.Type == "type") + { + bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); + if (!found) + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) + { + string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); + string enumPath = typeof(TF_DataType).Name + "." + name; + dic[attr.Name] = ("TF_DataType", enumPath); + } + else + { + dic[attr.Name] = ("TF_DataType", "NOVALUE"); + } + } + } + else if (attr.Type == "int") + { + if(op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name)) + { + continue; + } + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) + { + dic[attr.Name] = ("int", attr.DefaultValue.I.ToString()); + } + else + { + dic[attr.Name] = ("int", "0"); + } + } + else if (attr.Type == "float") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) + { + dic[attr.Name] = ("float", attr.DefaultValue.F.ToString() + "f"); + } + else + { + dic[attr.Name] = ("float", "NOVALUE"); + } + } + else if (attr.Type == "string") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) + { + dic[attr.Name] = ("string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\""); + } + else + { + dic[attr.Name] = ("string", "NOVALUE"); + } + } + else if (attr.Type == "bool") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) + { + dic[attr.Name] = ("bool", attr.DefaultValue.B.ToString().ToLower()); + } + else + { + dic[attr.Name] = ("bool", "NOVALUE"); + } + } + else if (attr.Type == "shape") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) + { + dic[attr.Name] = ("Shape", $"null"); + } + else + { + dic[attr.Name] = ("Shape", "NOVALUE"); + } + } + else if (attr.Type == "list(type)") + { + dic[attr.Name] = ("TF_DataType[]", "NOVALUE"); + } + else if (attr.Type == "list(shape)") + { + dic[attr.Name] = ("Shape[]", "NOVALUE"); + } + else if (attr.Type == "list(string)") + { + dic[attr.Name] = ("string[]", "NOVALUE"); + } + else if (attr.Type == "list(int)") + { + dic[attr.Name] = ("int[]", "NOVALUE"); + } + else if (attr.Type == "list(float)") + { + dic[attr.Name] = ("float[]", "NOVALUE"); + } + else if (attr.Type == "func") + { + dic[attr.Name] = ("Func", "NOVALUE"); + } + else if (attr.Type == "list(func)") + { + dic[attr.Name] = ("Func[]", "NOVALUE"); + } + else if (attr.Type == "tensor") + { + dic[attr.Name] = ("TensorProto", "NOVALUE"); + } + else + { + throw new NotImplementedException(); + } + } + return dic; + } + + private static bool HasRefArgs(OpDef op) + { + return op.InputArg.Any(x => x.IsRef); + } + } +} diff --git a/Tensorflow.CodeGen/GenOpsWriter.cs b/Tensorflow.CodeGen/GenOpsWriter.cs new file mode 100644 index 00000000..83ca6e0b --- /dev/null +++ b/Tensorflow.CodeGen/GenOpsWriter.cs @@ -0,0 +1,80 @@ +using Protobuf.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.CodeGen +{ + public class GenOpsWriter + { + private string _basePath; + private Dictionary _opMap; + private OpClassifier _opClassifier; + private FunctionGenerator _g = new(); + + public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename) + { + _basePath = basePath; + + var opDefs = ReadAllOpDefs(opDefFilename); + _opMap = opDefs.Op.ToDictionary( + x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); + _opClassifier = new OpClassifier(pythonFilesDirectory); + } + + public void WriteAll() + { + foreach(var (target, set) in _opClassifier.OpSet) + { + StringBuilder sb = new StringBuilder(); + + // Write file header. + sb.AppendLine("/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/"); + sb.AppendLine(); + + // Add commonly used namespaces. + sb.AppendLine("using Tensorflow.Eager;"); + sb.AppendLine("using Tensorflow.Contexts;"); + sb.AppendLine("using static Tensorflow.Binding;"); + sb.AppendLine(); + + // Specify the namespace + sb.AppendLine("namespace Tensorflow;"); + sb.AppendLine(); + + // Write class name + sb.AppendLine($"internal static class {target}"); + sb.AppendLine("{"); + + foreach(var funcName in set) + { + if(_opMap.ContainsKey(funcName)) + { + var opDef = _opMap[funcName]; + _g.AppendFunction(opDef, sb); + } + else if (funcName.StartsWith("_")) + { + var opDef = _opMap[funcName.Substring(1)]; + _g.AppendFunction(opDef, sb); + } + } + + // Close class scope. + sb.AppendLine("}"); + + string fullFilePath = Path.Combine(_basePath, $"{target}.cs"); + File.WriteAllText(fullFilePath, sb.ToString()); + } + } + + private OpList ReadAllOpDefs(string path) + { + var text = File.ReadAllText(path); + var opDefs = OpList.Parser.ParseText(text); + return opDefs; + } + } +} diff --git a/Tensorflow.CodeGen/OpClassifier.cs b/Tensorflow.CodeGen/OpClassifier.cs new file mode 100644 index 00000000..2ea2f35e --- /dev/null +++ b/Tensorflow.CodeGen/OpClassifier.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Text.RegularExpressions; + +namespace Tensorflow.CodeGen +{ + public class OpClassifier + { + private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$"; + private static readonly string _pythonFunctionPattern = @"def\s+(\w+)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*\w+\s*=None\s*\):"; + private Dictionary> _opSet = new(); + public Dictionary> OpSet => _opSet; + public OpClassifier(string pythonFileFolder) + { + DirectoryInfo directory = new DirectoryInfo(pythonFileFolder); + + foreach (FileInfo file in directory.GetFiles()) + { + if (Regex.IsMatch(file.Name, _filenamePattern)) + { + string filenamePrefix = file.Name.Split('.')[0]; + string content = File.ReadAllText(file.FullName); + var matches = Regex.Matches(content, _pythonFunctionPattern); + foreach(Match match in matches) + { + var funcName = match.Groups[1].Value; + if (!funcName.EndsWith("_eager_fallback")) + { + _opSet.SetDefault(filenamePrefix, new HashSet()).Add(funcName); + } + } + } + } + } + } +} diff --git a/Tensorflow.CodeGen/Program.cs b/Tensorflow.CodeGen/Program.cs new file mode 100644 index 00000000..d46dcdcb --- /dev/null +++ b/Tensorflow.CodeGen/Program.cs @@ -0,0 +1,12 @@ +using OneOf.Types; +using Protobuf.Text; +using System.Diagnostics; +using System.Text; +using System.Xml.Linq; +using Tensorflow.CodeGen; + +GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", + @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", + @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); + +writer.WriteAll(); diff --git a/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj b/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj new file mode 100644 index 00000000..61273d01 --- /dev/null +++ b/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj @@ -0,0 +1,18 @@ + + + + Exe + net6.0 + enable + enable + + + + + + + + + + + diff --git a/Tensorflow.CodeGen/Utils.cs b/Tensorflow.CodeGen/Utils.cs new file mode 100644 index 00000000..8cf21dee --- /dev/null +++ b/Tensorflow.CodeGen/Utils.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection.Metadata.Ecma335; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.CodeGen +{ + public static class Utils + { + public static string ConvertToUnderscore(string input) + { + if (string.IsNullOrEmpty(input)) + { + return input; + } + + StringBuilder result = new StringBuilder(); + + int state = 0; // the previous char was not lowered. + for (int i = 0; i < input.Length; i++) + { + char current = input[i]; + + // 首字母不需要添加下划线 + if (i != 0 && char.IsUpper(current)) + { + if(state == 0) + { + result.Append("_"); + state = 1; + } + result.Append(char.ToLower(current)); + } + else + { + result.Append(char.ToLower(current)); + state = 0; + } + } + + return result.ToString(); + } + } +}