Browse Source

feat: add code generator of ops.

tags/v0.110.0-LSTM-Model
Yaohui Liu Haiping 2 years ago
parent
commit
36b19df42d
7 changed files with 786 additions and 0 deletions
  1. +41
    -0
      TensorFlow.NET.sln
  2. +550
    -0
      Tensorflow.CodeGen/FunctionGenerator.cs
  3. +80
    -0
      Tensorflow.CodeGen/GenOpsWriter.cs
  4. +39
    -0
      Tensorflow.CodeGen/OpClassifier.cs
  5. +12
    -0
      Tensorflow.CodeGen/Program.cs
  6. +18
    -0
      Tensorflow.CodeGen/Tensorflow.CodeGen.csproj
  7. +46
    -0
      Tensorflow.CodeGen/Utils.cs

+ 41
- 0
TensorFlow.NET.sln View File

@@ -35,6 +35,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "helpers", "helpers", "{E1A5
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest.RedistHolder", "helpers\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj", "{62D543A2-8846-45A3-829B-5754B094A8E2}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.CodeGen", "Tensorflow.CodeGen\Tensorflow.CodeGen.csproj", "{BADBB104-2F03-4824-A249-803A871D8122}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "protobuf.Text", "..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj", "{151B3A8A-8576-4190-BD58-F42944A49718}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -282,6 +286,42 @@ Global
{62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x64.Build.0 = Release|Any CPU
{62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.ActiveCfg = Release|Any CPU
{62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.Build.0 = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|Any CPU.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x64.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x64.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x86.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x86.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|Any CPU.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x64.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x64.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x86.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x86.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|Any CPU.ActiveCfg = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|Any CPU.Build.0 = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x64.ActiveCfg = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x64.Build.0 = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x86.ActiveCfg = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x86.Build.0 = Release|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Debug|Any CPU.Build.0 = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x64.ActiveCfg = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x64.Build.0 = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x86.ActiveCfg = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Debug|x86.Build.0 = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.GPU|Any CPU.Build.0 = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x64.ActiveCfg = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x64.Build.0 = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x86.ActiveCfg = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.GPU|x86.Build.0 = Debug|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Release|Any CPU.ActiveCfg = Release|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Release|Any CPU.Build.0 = Release|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Release|x64.ActiveCfg = Release|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Release|x64.Build.0 = Release|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Release|x86.ActiveCfg = Release|Any CPU
{151B3A8A-8576-4190-BD58-F42944A49718}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -300,6 +340,7 @@ Global
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18} = {01A1787F-A9BE-4221-84E8-6360DD010AB6}
{7DEA8760-E401-4872-81F3-405F185A13A0} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
{62D543A2-8846-45A3-829B-5754B094A8E2} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{BADBB104-2F03-4824-A249-803A871D8122} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {2DEAD3CC-486B-4918-A607-50B0DE7B114A}


+ 550
- 0
Tensorflow.CodeGen/FunctionGenerator.cs View File

@@ -0,0 +1,550 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Text;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp;

namespace Tensorflow.CodeGen
{
public class FunctionGenerator
{
public void AppendFunction(OpDef op, StringBuilder sb)
{
// TODO: add descriptions
sb.Append("public static ");
int outputArgsCount = op.OutputArg.Count;
if (outputArgsCount > 1)
{
sb.Append("Tensor[] ");
}
else if (outputArgsCount == 1)
{
sb.Append("Tensor ");
}
else
{
sb.Append("Operation ");
}
string funcName = Utils.ConvertToUnderscore(op.Name);
var token = SyntaxFactory.ParseToken(funcName);
if (token.IsKeyword())
{
funcName = $"_{funcName}";
}
sb.Append($" {funcName}(");

// define args
AppendArgs(op, sb);
sb.Append(")\n{\n");

// begin to write main body
sb.AppendLine("var _ctx = tf.Context;");
sb.AppendLine("if(_ctx.executing_eagerly()){");

if(HasRefArgs(op))
{
var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null);
sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {possibleRefArg.Name} is a ref.\");");
}
else
{
sb.Append("try\n{\n");

AppendFastPathExecute(op, sb);
if (outputArgsCount == 0)
{
sb.AppendLine("return null;");
}
else if (outputArgsCount == 1)
{
sb.AppendLine("return _fast_path_result[0];");
}
else
{
sb.AppendLine("return _fast_path_result;");
}

sb.AppendLine("}"); // try

sb.Append("catch(Exception)\n{\n");
sb.AppendLine("}"); // catch

sb.Append("try\n{\n");
AppendEagerFallbackCall(op, sb);
sb.AppendLine("}"); // try

sb.Append("catch(Exception)\n{\n");
sb.AppendLine("}"); // catch
}

sb.AppendLine("}"); // if

// begin to use op helper.
AppendOpHelperCall(op, sb);
sb.AppendLine("var _result = _op.outputs;");

// check if it needs to record gradient.
sb.Append("if(_execute.must_record_gradient())\n{\n");
sb.Append("object[] _attrs = new object[]{");
foreach (var attr in op.Attr)
{
string attrRealName = attr.Name;
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword())
{
attrRealName += "_";
}
if (attr.Type == "type")
{
sb.Append($"\"{attr.Name}\", _op._get_attr_type(\"{attrRealName}\"), ");
}
else if (attr.Type == "int")
{
sb.Append($"\"{attr.Name}\", _op._get_attr_int(\"{attrRealName}\"), ");
}
else if (attr.Type == "bool")
{
sb.Append($"\"{attr.Name}\", _op._get_attr_bool(\"{attrRealName}\"), ");
}
else
{
sb.Append($"\"{attr.Name}\", _op.get_attr(\"{attr.Name}\"), ");
}
}
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
{
sb.Remove(sb.Length - 2, 2);
}
sb.Append("};\n");
sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);");

sb.AppendLine("}"); // if

if (outputArgsCount == 0)
{
sb.AppendLine("return _op;");
}
else if (outputArgsCount == 1)
{
sb.AppendLine("return _result[0];");
}
else
{
sb.AppendLine("return _result;");
}
sb.AppendLine("}"); // body

sb.AppendLine();

AppendEagerFallbackDefinition(op, sb);
}

public void AppendArgs(OpDef op, StringBuilder sb)
{
foreach (var arg in op.InputArg)
{
string argName = arg.Name;
var token = SyntaxFactory.ParseToken(argName);
if (token.IsKeyword())
{
argName = $"{argName}_";
}
if (!string.IsNullOrEmpty(arg.NumberAttr))
{
sb.Append($"Tensors {argName}, ");
}
else
{
sb.Append($"Tensor {argName}, ");
}
}
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, (typeStr, value)) in attrValueDic)
{
var token = SyntaxFactory.ParseToken(key);
string realKey = key;
if (token.IsKeyword())
{
realKey += "_";
}
if (value != "NOVALUE")
{
sb.Append($"{typeStr} {realKey} = {value}, ");
}
else
{
sb.Append($"{typeStr} {realKey}, ");
}
}
sb.Append($"string? name = null");
}

public void AppendFastPathExecute(OpDef op, StringBuilder sb)
{
sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name, ");
foreach (var arg in op.InputArg)
{
string attrArgName = arg.Name;
if (SyntaxFactory.ParseToken(attrArgName).IsKeyword())
{
attrArgName += "_";
}
sb.Append($"{attrArgName}, ");
}
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, _) in attrValueDic)
{
sb.Append($"\"{key}\", {key}, ");
}
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
{
sb.Remove(sb.Length - 2, 2);
}
sb.Append("));\n");
}

public void AppendEagerFallbackCall(OpDef op, StringBuilder sb)
{
string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback";
sb.Append($"return {funcName}(");
foreach (var arg in op.InputArg)
{
string inputArgRealName = arg.Name;
if (SyntaxFactory.ParseToken(inputArgRealName).IsKeyword())
{
inputArgRealName += "_";
}
sb.Append($"{inputArgRealName}, ");
}
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, _) in attrValueDic)
{
string keyRealName = key;
if (SyntaxFactory.ParseToken(keyRealName).IsKeyword())
{
keyRealName += "_";
}
sb.Append($"{key}: {keyRealName}, ");
}
sb.Append("name: name, ctx: _ctx);\n");
}

public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb)
{
sb.Append("public static Tensor");
int outputArgsCount = op.OutputArg.Count;
if (outputArgsCount > 1)
{
sb.Append("[]");
}
string opName = op.Name;
string funcName = Utils.ConvertToUnderscore(op.Name);
sb.Append($" {funcName}_eager_fallback(");
AppendFallBackFunctionArgs(op, sb);
sb.Append(")\n{\n");

var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null);
if (possibleRefArg is not null)
{
sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." +
$" Arg '{possibleRefArg.Name}' is a ref.\");");
sb.AppendLine("}"); // body
return;
}

sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
foreach (var arg in op.InputArg)
{
string realArgName = arg.Name;
if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
{
realArgName = $"{realArgName}_";
}
sb.Append($"{realArgName}, ");
}
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
{
sb.Remove(sb.Length - 2, 2);
}
sb.Append("};\n");

sb.Append("object[] _attrs = new object[]{");
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var attr in op.Attr)
{
if (attr.Type == "type")
{
bool found = false;
foreach (var arg in op.InputArg)
{
string realArgName = arg.Name;
if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
{
realArgName = $"{realArgName}_";
}
if (arg.TypeAttr == attr.Name)
{
sb.Append($"\"{attr.Name}\", {realArgName}.dtype, ");
found = true;
break;
}
}
if (!found)
{
if (attr.Name.StartsWith("T") && attr.Name.Length > 1)
{
string paramName = attr.Name.Substring(1);
if (SyntaxFactory.ParseToken(paramName).IsKeyword())
{
paramName = $"{paramName}_";
}
sb.Append($"\"{attr.Name}\", {paramName}.dtype, ");
}
else
{
string attrRealName = attr.Name;
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword())
{
attrRealName = $"{attrRealName}_";
}
sb.Append($"\"{attr.Name}\", {attrRealName}, ");
}
}
}
else if(attr.Type == "int" && (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name)))
{
bool found = false;
foreach (var arg in op.InputArg)
{
string realArgName = arg.Name;
if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
{
realArgName = $"{realArgName}_";
}
if (arg.NumberAttr == attr.Name)
{
sb.Append($"\"{attr.Name}\", {realArgName}.Length, ");
found = true;
break;
}
}
}
else
{
sb.Append($"\"{attr.Name}\", {attr.Name}, ");
}
}
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
{
sb.Remove(sb.Length - 2, 2);
}
sb.Append("};\n");

sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, " +
$"attrs: _attrs, ctx: ctx, name: name);");

sb.Append("if(_execute.must_record_gradient())\n{\n");

sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);");

sb.AppendLine("}"); // if

if (outputArgsCount == 0)
{
sb.AppendLine("return null;");
}
else if (outputArgsCount == 1)
{
sb.AppendLine("return _result[0];");
}
else
{
sb.AppendLine("return _result;");
}

sb.AppendLine("}"); // body
}

public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb)
{
foreach (var arg in op.InputArg)
{
string argName = arg.Name;
var token = SyntaxFactory.ParseToken(argName);
if (token.IsKeyword())
{
argName = $"{argName}_";
}
if (!string.IsNullOrEmpty(arg.NumberAttr))
{
sb.Append($"Tensors {argName}, ");
}
else
{
sb.Append($"Tensor {argName}, ");
}
}
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, (typeStr, _)) in attrValueDic)
{
var token = SyntaxFactory.ParseToken(key);
string realKey = key;
if (token.IsKeyword())
{
realKey += "_";
}
sb.Append($"{typeStr} {realKey}, ");
}
sb.Append($"string name, Context ctx");
}

public void AppendOpHelperCall(OpDef op, StringBuilder sb)
{
sb.AppendLine("Dictionary<string, object> keywords = new();");
foreach (var arg in op.InputArg)
{
string realArgName = arg.Name;
if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
{
realArgName += "_";
}
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};");
}
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, _) in attrValueDic)
{
sb.Append($"keywords[\"{key}\"] = {key};");
}
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
}

// key, (type string, default value)
public Dictionary<string, (string, string)> GetAttrsDefaultValue(OpDef op)
{
Dictionary<string, (string, string)> dic = new();
foreach (var attr in op.Attr)
{
if (attr.Type == "type")
{
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name);
if (!found)
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype());
string enumPath = typeof(TF_DataType).Name + "." + name;
dic[attr.Name] = ("TF_DataType", enumPath);
}
else
{
dic[attr.Name] = ("TF_DataType", "NOVALUE");
}
}
}
else if (attr.Type == "int")
{
if(op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name))
{
continue;
}
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I)
{
dic[attr.Name] = ("int", attr.DefaultValue.I.ToString());
}
else
{
dic[attr.Name] = ("int", "0");
}
}
else if (attr.Type == "float")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F)
{
dic[attr.Name] = ("float", attr.DefaultValue.F.ToString() + "f");
}
else
{
dic[attr.Name] = ("float", "NOVALUE");
}
}
else if (attr.Type == "string")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
{
dic[attr.Name] = ("string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"");
}
else
{
dic[attr.Name] = ("string", "NOVALUE");
}
}
else if (attr.Type == "bool")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B)
{
dic[attr.Name] = ("bool", attr.DefaultValue.B.ToString().ToLower());
}
else
{
dic[attr.Name] = ("bool", "NOVALUE");
}
}
else if (attr.Type == "shape")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape)
{
dic[attr.Name] = ("Shape", $"null");
}
else
{
dic[attr.Name] = ("Shape", "NOVALUE");
}
}
else if (attr.Type == "list(type)")
{
dic[attr.Name] = ("TF_DataType[]", "NOVALUE");
}
else if (attr.Type == "list(shape)")
{
dic[attr.Name] = ("Shape[]", "NOVALUE");
}
else if (attr.Type == "list(string)")
{
dic[attr.Name] = ("string[]", "NOVALUE");
}
else if (attr.Type == "list(int)")
{
dic[attr.Name] = ("int[]", "NOVALUE");
}
else if (attr.Type == "list(float)")
{
dic[attr.Name] = ("float[]", "NOVALUE");
}
else if (attr.Type == "func")
{
dic[attr.Name] = ("Func<Tensors, Tensors>", "NOVALUE");
}
else if (attr.Type == "list(func)")
{
dic[attr.Name] = ("Func<Tensors, Tensors>[]", "NOVALUE");
}
else if (attr.Type == "tensor")
{
dic[attr.Name] = ("TensorProto", "NOVALUE");
}
else
{
throw new NotImplementedException();
}
}
return dic;
}

private static bool HasRefArgs(OpDef op)
{
return op.InputArg.Any(x => x.IsRef);
}
}
}

+ 80
- 0
Tensorflow.CodeGen/GenOpsWriter.cs View File

@@ -0,0 +1,80 @@
using Protobuf.Text;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.CodeGen
{
public class GenOpsWriter
{
private string _basePath;
private Dictionary<string, OpDef> _opMap;
private OpClassifier _opClassifier;
private FunctionGenerator _g = new();

public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename)
{
_basePath = basePath;

var opDefs = ReadAllOpDefs(opDefFilename);
_opMap = opDefs.Op.ToDictionary(
x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x);
_opClassifier = new OpClassifier(pythonFilesDirectory);
}

public void WriteAll()
{
foreach(var (target, set) in _opClassifier.OpSet)
{
StringBuilder sb = new StringBuilder();

// Write file header.
sb.AppendLine("/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/");
sb.AppendLine();

// Add commonly used namespaces.
sb.AppendLine("using Tensorflow.Eager;");
sb.AppendLine("using Tensorflow.Contexts;");
sb.AppendLine("using static Tensorflow.Binding;");
sb.AppendLine();

// Specify the namespace
sb.AppendLine("namespace Tensorflow;");
sb.AppendLine();

// Write class name
sb.AppendLine($"internal static class {target}");
sb.AppendLine("{");

foreach(var funcName in set)
{
if(_opMap.ContainsKey(funcName))
{
var opDef = _opMap[funcName];
_g.AppendFunction(opDef, sb);
}
else if (funcName.StartsWith("_"))
{
var opDef = _opMap[funcName.Substring(1)];
_g.AppendFunction(opDef, sb);
}
}

// Close class scope.
sb.AppendLine("}");

string fullFilePath = Path.Combine(_basePath, $"{target}.cs");
File.WriteAllText(fullFilePath, sb.ToString());
}
}

private OpList ReadAllOpDefs(string path)
{
var text = File.ReadAllText(path);
var opDefs = OpList.Parser.ParseText(text);
return opDefs;
}
}
}

+ 39
- 0
Tensorflow.CodeGen/OpClassifier.cs View File

@@ -0,0 +1,39 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Text.RegularExpressions;

namespace Tensorflow.CodeGen
{
public class OpClassifier
{
private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$";
private static readonly string _pythonFunctionPattern = @"def\s+(\w+)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*\w+\s*=None\s*\):";
private Dictionary<string, HashSet<string>> _opSet = new();
public Dictionary<string, HashSet<string>> OpSet => _opSet;
public OpClassifier(string pythonFileFolder)
{
DirectoryInfo directory = new DirectoryInfo(pythonFileFolder);

foreach (FileInfo file in directory.GetFiles())
{
if (Regex.IsMatch(file.Name, _filenamePattern))
{
string filenamePrefix = file.Name.Split('.')[0];
string content = File.ReadAllText(file.FullName);
var matches = Regex.Matches(content, _pythonFunctionPattern);
foreach(Match match in matches)
{
var funcName = match.Groups[1].Value;
if (!funcName.EndsWith("_eager_fallback"))
{
_opSet.SetDefault(filenamePrefix, new HashSet<string>()).Add(funcName);
}
}
}
}
}
}
}

+ 12
- 0
Tensorflow.CodeGen/Program.cs View File

@@ -0,0 +1,12 @@
using OneOf.Types;
using Protobuf.Text;
using System.Diagnostics;
using System.Text;
using System.Xml.Linq;
using Tensorflow.CodeGen;

GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops",
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt");

writer.WriteAll();

+ 18
- 0
Tensorflow.CodeGen/Tensorflow.CodeGen.csproj View File

@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>

</Project>

+ 46
- 0
Tensorflow.CodeGen/Utils.cs View File

@@ -0,0 +1,46 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.CodeGen
{
public static class Utils
{
public static string ConvertToUnderscore(string input)
{
if (string.IsNullOrEmpty(input))
{
return input;
}

StringBuilder result = new StringBuilder();

int state = 0; // the previous char was not lowered.
for (int i = 0; i < input.Length; i++)
{
char current = input[i];

// 首字母不需要添加下划线
if (i != 0 && char.IsUpper(current))
{
if(state == 0)
{
result.Append("_");
state = 1;
}
result.Append(char.ToLower(current));
}
else
{
result.Append(char.ToLower(current));
state = 0;
}
}

return result.ToString();
}
}
}

Loading…
Cancel
Save