Browse Source

fix: revise wrong behaviors of op code generator.

pull/1063/head
Yaohui Liu 2 years ago
parent
commit
2295a04ecd
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
6 changed files with 242 additions and 98 deletions
  1. +202
    -82
      Tensorflow.CodeGen/FunctionGenerator.cs
  2. +2
    -2
      Tensorflow.CodeGen/GenOpsWriter.cs
  3. +21
    -9
      Tensorflow.CodeGen/OpClassifier.cs
  4. +2
    -0
      Tensorflow.CodeGen/Program.cs
  5. +3
    -2
      Tensorflow.CodeGen/Tensorflow.CodeGen.csproj
  6. +12
    -3
      Tensorflow.CodeGen/Utils.cs

+ 202
- 82
Tensorflow.CodeGen/FunctionGenerator.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Linq.Expressions;
using System.Reflection.Metadata.Ecma335; using System.Reflection.Metadata.Ecma335;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -16,17 +17,17 @@ namespace Tensorflow.CodeGen
// TODO: add descriptions // TODO: add descriptions
sb.Append("public static "); sb.Append("public static ");
int outputArgsCount = op.OutputArg.Count; int outputArgsCount = op.OutputArg.Count;
if (outputArgsCount > 1)
if (outputArgsCount == 0)
{ {
sb.Append("Tensor[] ");
sb.Append("Operation ");
} }
else if (outputArgsCount == 1)
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
{ {
sb.Append("Tensor "); sb.Append("Tensor ");
} }
else else
{ {
sb.Append("Operation ");
sb.Append("Tensor[] ");
} }
string funcName = Utils.ConvertToUnderscore(op.Name); string funcName = Utils.ConvertToUnderscore(op.Name);
var token = SyntaxFactory.ParseToken(funcName); var token = SyntaxFactory.ParseToken(funcName);
@@ -42,6 +43,17 @@ namespace Tensorflow.CodeGen


// begin to write main body // begin to write main body
sb.AppendLine("var _ctx = tf.Context;"); sb.AppendLine("var _ctx = tf.Context;");

var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues);
// deal with dynamic default values.
foreach(var (name, expr) in dynamicDefaultValues)
{
sb.AppendLine($"if({name} is null)");
sb.AppendLine("{");
sb.AppendLine($"{name} = {expr};");
sb.AppendLine("}");
}

sb.AppendLine("if(_ctx.executing_eagerly()){"); sb.AppendLine("if(_ctx.executing_eagerly()){");


if(HasRefArgs(op)) if(HasRefArgs(op))
@@ -58,7 +70,7 @@ namespace Tensorflow.CodeGen
{ {
sb.AppendLine("return null;"); sb.AppendLine("return null;");
} }
else if (outputArgsCount == 1)
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
{ {
sb.AppendLine("return _fast_path_result[0];"); sb.AppendLine("return _fast_path_result[0];");
} }
@@ -82,6 +94,17 @@ namespace Tensorflow.CodeGen


sb.AppendLine("}"); // if sb.AppendLine("}"); // if


foreach(var (name, type, value) in attrValueDic.Where(x => x.Item2 == "string"))
{
if(value != "NOVALUE")
{
sb.AppendLine($"if({name} is null)");
sb.AppendLine("{");
sb.AppendLine($"{name} = {value};");
sb.AppendLine("}");
}
}

// begin to use op helper. // begin to use op helper.
AppendOpHelperCall(op, sb); AppendOpHelperCall(op, sb);
sb.AppendLine("var _result = _op.outputs;"); sb.AppendLine("var _result = _op.outputs;");
@@ -126,7 +149,7 @@ namespace Tensorflow.CodeGen
{ {
sb.AppendLine("return _op;"); sb.AppendLine("return _op;");
} }
else if (outputArgsCount == 1)
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
{ {
sb.AppendLine("return _result[0];"); sb.AppendLine("return _result[0];");
} }
@@ -160,8 +183,8 @@ namespace Tensorflow.CodeGen
sb.Append($"Tensor {argName}, "); sb.Append($"Tensor {argName}, ");
} }
} }
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, (typeStr, value)) in attrValueDic)
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues);
foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE"))
{ {
var token = SyntaxFactory.ParseToken(key); var token = SyntaxFactory.ParseToken(key);
string realKey = key; string realKey = key;
@@ -169,21 +192,25 @@ namespace Tensorflow.CodeGen
{ {
realKey += "_"; realKey += "_";
} }
if (value != "NOVALUE")
{
sb.Append($"{typeStr} {realKey} = {value}, ");
}
else
sb.Append($"{typeStr} {realKey}, ");
}
foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 != "NOVALUE"))
{
var token = SyntaxFactory.ParseToken(key);
string realKey = key;
if (token.IsKeyword())
{ {
sb.Append($"{typeStr} {realKey}, ");
realKey += "_";
} }
sb.Append($"{typeStr} {realKey} = {value}, ");
} }
sb.Append($"string? name = null"); sb.Append($"string? name = null");
} }


public void AppendFastPathExecute(OpDef op, StringBuilder sb) public void AppendFastPathExecute(OpDef op, StringBuilder sb)
{ {
sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name, ");
sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)");
sb.Append("{ args = new object[]{ ");
foreach (var arg in op.InputArg) foreach (var arg in op.InputArg)
{ {
string attrArgName = arg.Name; string attrArgName = arg.Name;
@@ -193,16 +220,23 @@ namespace Tensorflow.CodeGen
} }
sb.Append($"{attrArgName}, "); sb.Append($"{attrArgName}, ");
} }
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, _) in attrValueDic)
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
{ {
sb.Append($"\"{key}\", {key}, ");
sb.Remove(sb.Length - 2, 2);
}

sb.Append("}, attrs = new Dictionary<string, object>(){ ");
var attrValueDic = GetAttrsDefaultValue(op, out var _);
foreach (var (key, _, _) in attrValueDic)
{
sb.Append($"[\"{key}\"] = {key}, ");
} }

if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
{ {
sb.Remove(sb.Length - 2, 2); sb.Remove(sb.Length - 2, 2);
} }
sb.Append("));\n");
sb.Append("}});\n");
} }


public void AppendEagerFallbackCall(OpDef op, StringBuilder sb) public void AppendEagerFallbackCall(OpDef op, StringBuilder sb)
@@ -218,8 +252,8 @@ namespace Tensorflow.CodeGen
} }
sb.Append($"{inputArgRealName}, "); sb.Append($"{inputArgRealName}, ");
} }
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, _) in attrValueDic)
var attrValueDic = GetAttrsDefaultValue(op, out var _);
foreach (var (key, _, _) in attrValueDic)
{ {
string keyRealName = key; string keyRealName = key;
if (SyntaxFactory.ParseToken(keyRealName).IsKeyword()) if (SyntaxFactory.ParseToken(keyRealName).IsKeyword())
@@ -233,11 +267,19 @@ namespace Tensorflow.CodeGen


public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb) public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb)
{ {
sb.Append("public static Tensor");
sb.Append("public static ");
int outputArgsCount = op.OutputArg.Count; int outputArgsCount = op.OutputArg.Count;
if (outputArgsCount > 1)
if (outputArgsCount == 0)
{
sb.Append("Operation ");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
{
sb.Append("Tensor ");
}
else
{ {
sb.Append("[]");
sb.Append("Tensor[] ");
} }
string opName = op.Name; string opName = op.Name;
string funcName = Utils.ConvertToUnderscore(op.Name); string funcName = Utils.ConvertToUnderscore(op.Name);
@@ -254,24 +296,47 @@ namespace Tensorflow.CodeGen
return; return;
} }


sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
foreach (var arg in op.InputArg)
if(op.InputArg.Any(x => !string.IsNullOrEmpty(x.NumberAttr)))
{ {
string realArgName = arg.Name;
if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
sb.AppendLine("List<Tensor> _inputs_flat_list = new();");
foreach (var arg in op.InputArg)
{ {
realArgName = $"{realArgName}_";
string realArgName = arg.Name;
if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
{
realArgName = $"{realArgName}_";
}
if (string.IsNullOrEmpty(arg.NumberAttr))
{
sb.AppendLine($"_inputs_flat_list.Add({realArgName});");
}
else
{
sb.AppendLine($"_inputs_flat_list.AddRange({realArgName});");
}
} }
sb.Append($"{realArgName}, ");
sb.AppendLine($"var _inputs_flat = _inputs_flat_list.ToArray();");
} }
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
else
{ {
sb.Remove(sb.Length - 2, 2);
sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
foreach (var arg in op.InputArg)
{
string realArgName = arg.Name;
if (SyntaxFactory.ParseToken(realArgName).IsKeyword())
{
realArgName = $"{realArgName}_";
}
sb.Append($"{realArgName}, ");
}
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',')
{
sb.Remove(sb.Length - 2, 2);
}
sb.Append("};\n");
} }
sb.Append("};\n");


sb.Append("object[] _attrs = new object[]{"); sb.Append("object[] _attrs = new object[]{");
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var attr in op.Attr) foreach (var attr in op.Attr)
{ {
if (attr.Type == "type") if (attr.Type == "type")
@@ -293,27 +358,15 @@ namespace Tensorflow.CodeGen
} }
if (!found) if (!found)
{ {
if (attr.Name.StartsWith("T") && attr.Name.Length > 1)
{
string paramName = attr.Name.Substring(1);
if (SyntaxFactory.ParseToken(paramName).IsKeyword())
{
paramName = $"{paramName}_";
}
sb.Append($"\"{attr.Name}\", {paramName}.dtype, ");
}
else
string attrRealName = attr.Name;
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword())
{ {
string attrRealName = attr.Name;
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword())
{
attrRealName = $"{attrRealName}_";
}
sb.Append($"\"{attr.Name}\", {attrRealName}, ");
attrRealName = $"{attrRealName}_";
} }
sb.Append($"\"{attr.Name}\", {attrRealName}, ");
} }
} }
else if(attr.Type == "int" && (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name)))
else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name))
{ {
bool found = false; bool found = false;
foreach (var arg in op.InputArg) foreach (var arg in op.InputArg)
@@ -355,7 +408,7 @@ namespace Tensorflow.CodeGen
{ {
sb.AppendLine("return null;"); sb.AppendLine("return null;");
} }
else if (outputArgsCount == 1)
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
{ {
sb.AppendLine("return _result[0];"); sb.AppendLine("return _result[0];");
} }
@@ -386,8 +439,8 @@ namespace Tensorflow.CodeGen
sb.Append($"Tensor {argName}, "); sb.Append($"Tensor {argName}, ");
} }
} }
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, (typeStr, _)) in attrValueDic)
var attrValueDic = GetAttrsDefaultValue(op, out var _);
foreach (var (key, typeStr, _) in attrValueDic)
{ {
var token = SyntaxFactory.ParseToken(key); var token = SyntaxFactory.ParseToken(key);
string realKey = key; string realKey = key;
@@ -412,18 +465,19 @@ namespace Tensorflow.CodeGen
} }
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};");
} }
var attrValueDic = GetAttrsDefaultValue(op);
foreach (var (key, _) in attrValueDic)
var attrValueDic = GetAttrsDefaultValue(op, out var _);
foreach (var (key, _, _) in attrValueDic)
{ {
sb.Append($"keywords[\"{key}\"] = {key};");
sb.AppendLine($"keywords[\"{key}\"] = {key};");
} }
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
} }


// key, (type string, default value)
public Dictionary<string, (string, string)> GetAttrsDefaultValue(OpDef op)
// name, type string, default value
public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues)
{ {
Dictionary<string, (string, string)> dic = new();
dynamicDefaultValues = new();
List<(string, string, string)> res = new();
foreach (var attr in op.Attr) foreach (var attr in op.Attr)
{ {
if (attr.Type == "type") if (attr.Type == "type")
@@ -435,111 +489,177 @@ namespace Tensorflow.CodeGen
{ {
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype());
string enumPath = typeof(TF_DataType).Name + "." + name; string enumPath = typeof(TF_DataType).Name + "." + name;
dic[attr.Name] = ("TF_DataType", enumPath);
res.Add((attr.Name, "TF_DataType", enumPath));
} }
else else
{ {
dic[attr.Name] = ("TF_DataType", "NOVALUE");
res.Add((attr.Name, "TF_DataType", "NOVALUE"));
} }
} }
} }
else if (attr.Type == "int") else if (attr.Type == "int")
{ {
if(op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name))
if(op.InputArg.Any(x => x.NumberAttr == attr.Name))
{ {
continue; continue;
} }
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I)
{ {
dic[attr.Name] = ("int", attr.DefaultValue.I.ToString());
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString()));
} }
else else
{ {
dic[attr.Name] = ("int", "0");
res.Add((attr.Name, "int", "0"));
} }
} }
else if (attr.Type == "float") else if (attr.Type == "float")
{ {
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F)
{ {
dic[attr.Name] = ("float", attr.DefaultValue.F.ToString() + "f");
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f"));
} }
else else
{ {
dic[attr.Name] = ("float", "NOVALUE");
res.Add((attr.Name, "float", "NOVALUE"));
} }
} }
else if (attr.Type == "string") else if (attr.Type == "string")
{ {
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
{ {
dic[attr.Name] = ("string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"");
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\""));
} }
else else
{ {
dic[attr.Name] = ("string", "NOVALUE");
res.Add((attr.Name, "string", "NOVALUE"));
} }
} }
else if (attr.Type == "bool") else if (attr.Type == "bool")
{ {
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B)
{ {
dic[attr.Name] = ("bool", attr.DefaultValue.B.ToString().ToLower());
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower()));
} }
else else
{ {
dic[attr.Name] = ("bool", "NOVALUE");
res.Add((attr.Name, "bool", "NOVALUE"));
} }
} }
else if (attr.Type == "shape") else if (attr.Type == "shape")
{ {
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape)
{ {
dic[attr.Name] = ("Shape", $"null");
if (attr.DefaultValue.Shape.UnknownRank)
{
res.Add((attr.Name, "Shape", $"null"));
}
else
{
Shape shape = new Shape(attr.DefaultValue.Shape);
string expression = $"new Shape({string.Join(", ", shape.dims)})";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "Shape", $"null"));
}
} }
else else
{ {
dic[attr.Name] = ("Shape", "NOVALUE");
res.Add((attr.Name, "Shape", "NOVALUE"));
} }
} }
else if (attr.Type == "list(type)") else if (attr.Type == "list(type)")
{ {
dic[attr.Name] = ("TF_DataType[]", "NOVALUE");
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
List<TF_DataType> values = new();
foreach (var value in attr.DefaultValue.List.Type)
{
values.Add(value.as_tf_dtype());
}
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "TF_DataType[]", $"null"));
}
else
{
res.Add((attr.Name, "TF_DataType[]", "NOVALUE"));
}
} }
else if (attr.Type == "list(shape)") else if (attr.Type == "list(shape)")
{ {
dic[attr.Name] = ("Shape[]", "NOVALUE");
res.Add((attr.Name, "Shape[]", "NOVALUE"));
} }
else if (attr.Type == "list(string)") else if (attr.Type == "list(string)")
{ {
dic[attr.Name] = ("string[]", "NOVALUE");
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
{
List<string> values = new();
foreach (var value in attr.DefaultValue.List.S)
{
values.Add(value.ToStringUtf8());
}
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "string[]", $"null"));
}
else
{
res.Add((attr.Name, "string[]", "NOVALUE"));
}
} }
else if (attr.Type == "list(int)") else if (attr.Type == "list(int)")
{ {
dic[attr.Name] = ("int[]", "NOVALUE");
if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<int> values = new();
foreach(var value in attr.DefaultValue.List.I)
{
values.Add((int)value);
}
string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "int[]", $"null"));
}
else
{
res.Add((attr.Name, "int[]", "NOVALUE"));
}
} }
else if (attr.Type == "list(float)") else if (attr.Type == "list(float)")
{ {
dic[attr.Name] = ("float[]", "NOVALUE");
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<float> values = new();
foreach (var value in attr.DefaultValue.List.F)
{
values.Add(value);
}
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "float[]", $"null"));
}
else
{
res.Add((attr.Name, "float[]", "NOVALUE"));
}
} }
else if (attr.Type == "func") else if (attr.Type == "func")
{ {
dic[attr.Name] = ("Func<Tensors, Tensors>", "NOVALUE");
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE"));
} }
else if (attr.Type == "list(func)") else if (attr.Type == "list(func)")
{ {
dic[attr.Name] = ("Func<Tensors, Tensors>[]", "NOVALUE");
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE"));
} }
else if (attr.Type == "tensor") else if (attr.Type == "tensor")
{ {
dic[attr.Name] = ("TensorProto", "NOVALUE");
res.Add((attr.Name, "TensorProto", "NOVALUE"));
} }
else else
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
} }
return dic;
return res;
} }


private static bool HasRefArgs(OpDef op) private static bool HasRefArgs(OpDef op)


+ 2
- 2
Tensorflow.CodeGen/GenOpsWriter.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.CodeGen
var opDefs = ReadAllOpDefs(opDefFilename); var opDefs = ReadAllOpDefs(opDefFilename);
_opMap = opDefs.Op.ToDictionary( _opMap = opDefs.Op.ToDictionary(
x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x);
_opClassifier = new OpClassifier(pythonFilesDirectory);
_opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name)));
} }


public void WriteAll() public void WriteAll()
@@ -45,7 +45,7 @@ namespace Tensorflow.CodeGen
sb.AppendLine(); sb.AppendLine();


// Write class name // Write class name
sb.AppendLine($"internal static class {target}");
sb.AppendLine($"public static class {target}");
sb.AppendLine("{"); sb.AppendLine("{");


foreach(var funcName in set) foreach(var funcName in set)


+ 21
- 9
Tensorflow.CodeGen/OpClassifier.cs View File

@@ -10,27 +10,39 @@ namespace Tensorflow.CodeGen
public class OpClassifier public class OpClassifier
{ {
private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$"; private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$";
private static readonly string _pythonFunctionPattern = @"def\s+(\w+)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*\w+\s*=None\s*\):";
private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):";
private Dictionary<string, HashSet<string>> _opSet = new(); private Dictionary<string, HashSet<string>> _opSet = new();
public Dictionary<string, HashSet<string>> OpSet => _opSet; public Dictionary<string, HashSet<string>> OpSet => _opSet;
public OpClassifier(string pythonFileFolder)
public OpClassifier(string pythonFileFolder, IEnumerable<string> funcNames)
{ {
DirectoryInfo directory = new DirectoryInfo(pythonFileFolder); DirectoryInfo directory = new DirectoryInfo(pythonFileFolder);


Dictionary<string, string> fileContentMap = new();
foreach (FileInfo file in directory.GetFiles()) foreach (FileInfo file in directory.GetFiles())
{ {
if (Regex.IsMatch(file.Name, _filenamePattern)) if (Regex.IsMatch(file.Name, _filenamePattern))
{ {
Console.WriteLine(file.Name);
string filenamePrefix = file.Name.Split('.')[0]; string filenamePrefix = file.Name.Split('.')[0];
string content = File.ReadAllText(file.FullName); string content = File.ReadAllText(file.FullName);
var matches = Regex.Matches(content, _pythonFunctionPattern);
foreach(Match match in matches)
fileContentMap[filenamePrefix] = content;
}
}

foreach(var funcName in funcNames)
{
Console.WriteLine(funcName);
string funcPattern = @$"^def\s+{funcName}\(";
string fallbackFuncPattern = @$"^def\s+{funcName}_eager_fallback\(";
foreach (var (target, content) in fileContentMap)
{
if(content.Contains($"def {funcName}") && content.Contains($"def {funcName}_eager_fallback"))
{
_opSet.SetDefault(target, new HashSet<string>()).Add(funcName);
}
else if (content.Contains($"def _{funcName}") && content.Contains($"def _{funcName}_eager_fallback"))
{ {
var funcName = match.Groups[1].Value;
if (!funcName.EndsWith("_eager_fallback"))
{
_opSet.SetDefault(filenamePrefix, new HashSet<string>()).Add(funcName);
}
_opSet.SetDefault(target, new HashSet<string>()).Add(funcName);
} }
} }
} }


+ 2
- 0
Tensorflow.CodeGen/Program.cs View File

@@ -5,6 +5,8 @@ using System.Text;
using System.Xml.Linq; using System.Xml.Linq;
using Tensorflow.CodeGen; using Tensorflow.CodeGen;


//Console.WriteLine(Utils.ConvertToUnderscore("LRN"));

GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops",
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt");


+ 3
- 2
Tensorflow.CodeGen/Tensorflow.CodeGen.csproj View File

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">


<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
@@ -9,10 +9,11 @@


<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" />
<PackageReference Include="TensorFlow.NET" Version="0.100.5" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<ProjectReference Include="..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
<ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" />
</ItemGroup> </ItemGroup>


</Project> </Project>

+ 12
- 3
Tensorflow.CodeGen/Utils.cs View File

@@ -18,15 +18,24 @@ namespace Tensorflow.CodeGen


StringBuilder result = new StringBuilder(); StringBuilder result = new StringBuilder();


int state = 0; // the previous char was not lowered.
int state = 1; // the previous char was not lowered.
for (int i = 0; i < input.Length; i++) for (int i = 0; i < input.Length; i++)
{ {
char current = input[i]; char current = input[i];


// 首字母不需要添加下划线 // 首字母不需要添加下划线
if (i != 0 && char.IsUpper(current))
if (char.IsUpper(current))
{ {
if(state == 0)
if(i > 0)
{
char pre = input[i - 1];
if (char.IsDigit(pre))
{
result.Append(char.ToLower(current));
continue;
}
}
if (state == 0)
{ {
result.Append("_"); result.Append("_");
state = 1; state = 1;


Loading…
Cancel
Save