@@ -0,0 +1,263 @@ | |||||
using Microsoft.CodeAnalysis.CSharp; | |||||
using Protobuf.Text; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Reflection.Metadata.Ecma335; | |||||
using System.Text; | |||||
using System.Text.RegularExpressions; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.CodeGen | |||||
{ | |||||
public class DescriptionGenerator | |||||
{ | |||||
private static readonly string replaceStrInner = "~~%~~"; | |||||
private static readonly string replaceStrInnerQuotationMarks = "^%^"; | |||||
Dictionary<string, Dictionary<string, string>> _opDescriptions = new Dictionary<string, Dictionary<string, string>>(); | |||||
Dictionary<string, OpDef> _opDescriptionDefs = new Dictionary<string, OpDef>(); | |||||
public DescriptionGenerator(string apiDefDirectory) | |||||
{ | |||||
DirectoryInfo directory = new DirectoryInfo(apiDefDirectory); | |||||
int errors = 0; | |||||
foreach (FileInfo file in directory.GetFiles()) | |||||
{ | |||||
string target = file.Name.Split('.')[0].Split('_').Last(); | |||||
OpDef op = null; | |||||
try | |||||
{ | |||||
op = ReadOpDefs(file.FullName).Op[0]; | |||||
} | |||||
catch | |||||
{ | |||||
errors++; | |||||
continue; | |||||
} | |||||
_opDescriptionDefs[target] = op; | |||||
_opDescriptions[target] = new Dictionary<string, string>(); | |||||
foreach (var arg in op.InputArg) | |||||
{ | |||||
string argName = arg.Name; | |||||
var token = SyntaxFactory.ParseToken(argName); | |||||
if (token.IsKeyword()) | |||||
{ | |||||
argName = $"{argName}_"; | |||||
} | |||||
_opDescriptions[target][argName] = arg.Description ?? ""; | |||||
} | |||||
foreach (var arg in op.Attr) | |||||
{ | |||||
var token = SyntaxFactory.ParseToken(arg.Name); | |||||
string realKey = arg.Name; | |||||
if (token.IsKeyword()) | |||||
{ | |||||
realKey += "_"; | |||||
} | |||||
_opDescriptions[target][realKey] = arg.Description ?? ""; | |||||
} | |||||
_opDescriptions[target]["SUMMARY"] = op.Summary ?? ""; | |||||
_opDescriptions[target]["DESC"] = op.Description ?? ""; | |||||
} | |||||
Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " + | |||||
$"the failed files number is large, or ignore it."); | |||||
} | |||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="op"></param> | |||||
/// <param name="sb"></param> | |||||
public void AppendDescription(OpDef fullOp, StringBuilder sb) | |||||
{ | |||||
var opName = fullOp.Name; | |||||
if(_opDescriptions.TryGetValue(opName, out var op)) | |||||
{ | |||||
var def = _opDescriptionDefs[opName]; | |||||
sb.AppendLine("/// <summary>"); | |||||
sb.AppendLine($"/// {op["SUMMARY"]}"); | |||||
sb.AppendLine("/// </summary>"); | |||||
string totalDesc = op["DESC"]; | |||||
if (!string.IsNullOrEmpty(totalDesc)) | |||||
{ | |||||
totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\""); | |||||
sb.AppendLine("/// <remarks>"); | |||||
string[] lines = totalDesc.Split(replaceStrInner); | |||||
foreach (var line in lines) | |||||
{ | |||||
sb.AppendLine($"/// {line}"); | |||||
} | |||||
sb.AppendLine("/// </remarks>"); | |||||
} | |||||
var argNames = GetInputArgNames(fullOp); | |||||
foreach (var argName in argNames) | |||||
{ | |||||
if(op.TryGetValue(argName, out var desc)) | |||||
{ | |||||
desc = desc.Replace(replaceStrInnerQuotationMarks, "\""); | |||||
string[] lines = desc.Split(replaceStrInner); | |||||
sb.AppendLine($"/// <param name=\"{argName}\">"); | |||||
foreach (var line in lines) | |||||
{ | |||||
sb.AppendLine($"/// {line}"); | |||||
} | |||||
sb.AppendLine("/// </param>"); | |||||
} | |||||
else | |||||
{ | |||||
sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||||
} | |||||
} | |||||
List<string> returnValueDescs = new(); | |||||
foreach (var arg in def.OutputArg) | |||||
{ | |||||
if (!string.IsNullOrEmpty(arg.Description)) | |||||
{ | |||||
returnValueDescs.Add($"{arg.Name}: {arg.Description}"); | |||||
} | |||||
} | |||||
string returnValueDesc = ""; | |||||
if (returnValueDescs.Count > 0) | |||||
{ | |||||
returnValueDesc = string.Join(" && ", returnValueDescs); | |||||
} | |||||
sb.AppendLine($"/// <returns>{returnValueDesc}</returns>"); | |||||
} | |||||
else | |||||
{ | |||||
sb.AppendLine("/// <summary>"); | |||||
sb.AppendLine($"///"); | |||||
sb.AppendLine("/// </summary>"); | |||||
var argNames = GetInputArgNames(fullOp); | |||||
foreach (var argName in argNames) | |||||
{ | |||||
sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||||
} | |||||
sb.AppendLine($"/// <returns></returns>"); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="op"> | |||||
/// </param> | |||||
/// <returns></returns> | |||||
/// <remarks></remarks> | |||||
public List<string> GetInputArgNames(OpDef op) | |||||
{ | |||||
List<string> names = new(); | |||||
foreach (var arg in op.InputArg) | |||||
{ | |||||
string argName = arg.Name; | |||||
var token = SyntaxFactory.ParseToken(argName); | |||||
if (token.IsKeyword()) | |||||
{ | |||||
argName = $"{argName}_"; | |||||
} | |||||
names.Add(argName); | |||||
} | |||||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
foreach (var (key, typeStr, value) in attrValueDic) | |||||
{ | |||||
var token = SyntaxFactory.ParseToken(key); | |||||
string realKey = key; | |||||
if (token.IsKeyword()) | |||||
{ | |||||
realKey += "_"; | |||||
} | |||||
names.Add(realKey); | |||||
} | |||||
return names; | |||||
} | |||||
private static OpList ReadOpDefs(string path) | |||||
{ | |||||
var text = File.ReadAllText(path); | |||||
text = RemoveLintTags(text); | |||||
text = PreProcessText(text); | |||||
string pattern = @"<<END([\s\S]*?)END"; | |||||
// 定义用于替换的字符串 | |||||
string replaceStrPrefix = "\""; | |||||
string replaceStrSuffix = "\""; | |||||
// 将匹配到的文本段全部替换 | |||||
string replacedText = Regex.Replace(text, pattern, match => { | |||||
string matchedText = match.Value; | |||||
string innerText = match.Groups[1].Value; | |||||
innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks) | |||||
.Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符 | |||||
return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾 | |||||
}, RegexOptions.Multiline); | |||||
var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse<OpList>(replacedText); | |||||
return opDefs; | |||||
} | |||||
static string PreProcessText(string input) | |||||
{ | |||||
int depth = 0; | |||||
int endBlockDepth = -1; | |||||
StringBuilder sb = new StringBuilder(); | |||||
for (int i = 0; i < input.Length; i++) | |||||
{ | |||||
char c = input[i]; | |||||
if (c == '{') | |||||
{ | |||||
depth++; | |||||
sb.Append(c); | |||||
} | |||||
else if (c == '}') | |||||
{ | |||||
if (depth == endBlockDepth) | |||||
{ | |||||
sb.Append("END\n"); | |||||
endBlockDepth = -1; | |||||
} | |||||
sb.Append(c); | |||||
depth--; | |||||
} | |||||
else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "<<END") | |||||
{ | |||||
endBlockDepth = depth; | |||||
sb.Append("<<END"); | |||||
i += 4; | |||||
} | |||||
else if (c == 'E' && i + 3 < input.Length && input.Substring(i, 3) == "END") | |||||
{ | |||||
endBlockDepth = -1; | |||||
sb.Append("END"); | |||||
i += 2; | |||||
} | |||||
else | |||||
{ | |||||
sb.Append(c); | |||||
} | |||||
} | |||||
string output = sb.ToString(); | |||||
return output; | |||||
} | |||||
static string RemoveLintTags(string input) | |||||
{ | |||||
string[] lines = input.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None); | |||||
StringBuilder sb = new StringBuilder(); | |||||
foreach (string line in lines) | |||||
{ | |||||
if (!line.TrimStart().StartsWith("# LINT")) | |||||
{ | |||||
sb.AppendLine(line); | |||||
} | |||||
} | |||||
return sb.ToString().TrimEnd(); | |||||
} | |||||
} | |||||
} |
@@ -44,7 +44,7 @@ namespace Tensorflow.CodeGen | |||||
// begin to write main body | // begin to write main body | ||||
sb.AppendLine("var _ctx = tf.Context;"); | sb.AppendLine("var _ctx = tf.Context;"); | ||||
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
// deal with dynamic default values. | // deal with dynamic default values. | ||||
foreach(var (name, expr) in dynamicDefaultValues) | foreach(var (name, expr) in dynamicDefaultValues) | ||||
{ | { | ||||
@@ -183,7 +183,7 @@ namespace Tensorflow.CodeGen | |||||
sb.Append($"Tensor {argName}, "); | sb.Append($"Tensor {argName}, "); | ||||
} | } | ||||
} | } | ||||
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) | foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) | ||||
{ | { | ||||
var token = SyntaxFactory.ParseToken(key); | var token = SyntaxFactory.ParseToken(key); | ||||
@@ -226,7 +226,7 @@ namespace Tensorflow.CodeGen | |||||
} | } | ||||
sb.Append("}, attrs = new Dictionary<string, object>(){ "); | sb.Append("}, attrs = new Dictionary<string, object>(){ "); | ||||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
foreach (var (key, _, _) in attrValueDic) | foreach (var (key, _, _) in attrValueDic) | ||||
{ | { | ||||
sb.Append($"[\"{key}\"] = {key}, "); | sb.Append($"[\"{key}\"] = {key}, "); | ||||
@@ -252,7 +252,7 @@ namespace Tensorflow.CodeGen | |||||
} | } | ||||
sb.Append($"{inputArgRealName}, "); | sb.Append($"{inputArgRealName}, "); | ||||
} | } | ||||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
foreach (var (key, _, _) in attrValueDic) | foreach (var (key, _, _) in attrValueDic) | ||||
{ | { | ||||
string keyRealName = key; | string keyRealName = key; | ||||
@@ -439,7 +439,7 @@ namespace Tensorflow.CodeGen | |||||
sb.Append($"Tensor {argName}, "); | sb.Append($"Tensor {argName}, "); | ||||
} | } | ||||
} | } | ||||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
foreach (var (key, typeStr, _) in attrValueDic) | foreach (var (key, typeStr, _) in attrValueDic) | ||||
{ | { | ||||
var token = SyntaxFactory.ParseToken(key); | var token = SyntaxFactory.ParseToken(key); | ||||
@@ -465,7 +465,7 @@ namespace Tensorflow.CodeGen | |||||
} | } | ||||
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); | sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); | ||||
} | } | ||||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
foreach (var (key, _, _) in attrValueDic) | foreach (var (key, _, _) in attrValueDic) | ||||
{ | { | ||||
sb.AppendLine($"keywords[\"{key}\"] = {key};"); | sb.AppendLine($"keywords[\"{key}\"] = {key};"); | ||||
@@ -473,195 +473,6 @@ namespace Tensorflow.CodeGen | |||||
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); | sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); | ||||
} | } | ||||
// name, type string, default value | |||||
public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||||
{ | |||||
dynamicDefaultValues = new(); | |||||
List<(string, string, string)> res = new(); | |||||
foreach (var attr in op.Attr) | |||||
{ | |||||
if (attr.Type == "type") | |||||
{ | |||||
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||||
if (!found) | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
{ | |||||
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||||
string enumPath = typeof(TF_DataType).Name + "." + name; | |||||
res.Add((attr.Name, "TF_DataType", enumPath)); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||||
} | |||||
} | |||||
} | |||||
else if (attr.Type == "int") | |||||
{ | |||||
if(op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||||
{ | |||||
continue; | |||||
} | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||||
{ | |||||
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "int", "0")); | |||||
} | |||||
} | |||||
else if (attr.Type == "float") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||||
{ | |||||
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "float", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "string") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
{ | |||||
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "string", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "bool") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||||
{ | |||||
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "bool", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "shape") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||||
{ | |||||
if (attr.DefaultValue.Shape.UnknownRank) | |||||
{ | |||||
res.Add((attr.Name, "Shape", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
Shape shape = new Shape(attr.DefaultValue.Shape); | |||||
string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "Shape", $"null")); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "Shape", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(type)") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
{ | |||||
List<TF_DataType> values = new(); | |||||
foreach (var value in attr.DefaultValue.List.Type) | |||||
{ | |||||
values.Add(value.as_tf_dtype()); | |||||
} | |||||
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "TF_DataType[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(shape)") | |||||
{ | |||||
res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||||
} | |||||
else if (attr.Type == "list(string)") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
{ | |||||
List<string> values = new(); | |||||
foreach (var value in attr.DefaultValue.List.S) | |||||
{ | |||||
values.Add(value.ToStringUtf8()); | |||||
} | |||||
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "string[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "string[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(int)") | |||||
{ | |||||
if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
{ | |||||
List<int> values = new(); | |||||
foreach(var value in attr.DefaultValue.List.I) | |||||
{ | |||||
values.Add((int)value); | |||||
} | |||||
string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "int[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "int[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(float)") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
{ | |||||
List<float> values = new(); | |||||
foreach (var value in attr.DefaultValue.List.F) | |||||
{ | |||||
values.Add(value); | |||||
} | |||||
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "float[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "float[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "func") | |||||
{ | |||||
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||||
} | |||||
else if (attr.Type == "list(func)") | |||||
{ | |||||
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||||
} | |||||
else if (attr.Type == "tensor") | |||||
{ | |||||
res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
return res; | |||||
} | |||||
private static bool HasRefArgs(OpDef op) | private static bool HasRefArgs(OpDef op) | ||||
{ | { | ||||
return op.InputArg.Any(x => x.IsRef); | return op.InputArg.Any(x => x.IsRef); | ||||
@@ -12,16 +12,18 @@ namespace Tensorflow.CodeGen | |||||
private string _basePath; | private string _basePath; | ||||
private Dictionary<string, OpDef> _opMap; | private Dictionary<string, OpDef> _opMap; | ||||
private OpClassifier _opClassifier; | private OpClassifier _opClassifier; | ||||
private FunctionGenerator _g = new(); | |||||
private FunctionGenerator _fg = new(); | |||||
private DescriptionGenerator _dg; | |||||
public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename) | |||||
public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename) | |||||
{ | { | ||||
_basePath = basePath; | _basePath = basePath; | ||||
var opDefs = ReadAllOpDefs(opDefFilename); | |||||
var opDefs = Utils.ReadAllOpDefs(opDefFilename); | |||||
_opMap = opDefs.Op.ToDictionary( | _opMap = opDefs.Op.ToDictionary( | ||||
x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); | |||||
x => Utils.ConvertToUnderscore(x.Name), x => x); | |||||
_opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); | _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); | ||||
_dg = new DescriptionGenerator(apiDefFilesDirectory); | |||||
} | } | ||||
public void WriteAll() | public void WriteAll() | ||||
@@ -53,12 +55,17 @@ namespace Tensorflow.CodeGen | |||||
if(_opMap.ContainsKey(funcName)) | if(_opMap.ContainsKey(funcName)) | ||||
{ | { | ||||
var opDef = _opMap[funcName]; | var opDef = _opMap[funcName]; | ||||
_g.AppendFunction(opDef, sb); | |||||
// write the descriptions. | |||||
_dg.AppendDescription(opDef, sb); | |||||
// write the function body. | |||||
_fg.AppendFunction(opDef, sb); | |||||
} | } | ||||
else if (funcName.StartsWith("_")) | else if (funcName.StartsWith("_")) | ||||
{ | { | ||||
var opDef = _opMap[funcName.Substring(1)]; | var opDef = _opMap[funcName.Substring(1)]; | ||||
_g.AppendFunction(opDef, sb); | |||||
_fg.AppendFunction(opDef, sb); | |||||
} | } | ||||
} | } | ||||
@@ -69,12 +76,5 @@ namespace Tensorflow.CodeGen | |||||
File.WriteAllText(fullFilePath, sb.ToString()); | File.WriteAllText(fullFilePath, sb.ToString()); | ||||
} | } | ||||
} | } | ||||
private OpList ReadAllOpDefs(string path) | |||||
{ | |||||
var text = File.ReadAllText(path); | |||||
var opDefs = OpList.Parser.ParseText(text); | |||||
return opDefs; | |||||
} | |||||
} | } | ||||
} | } |
@@ -5,10 +5,9 @@ using System.Text; | |||||
using System.Xml.Linq; | using System.Xml.Linq; | ||||
using Tensorflow.CodeGen; | using Tensorflow.CodeGen; | ||||
//Console.WriteLine(Utils.ConvertToUnderscore("LRN")); | |||||
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | ||||
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | ||||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", | |||||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | ||||
writer.WriteAll(); | writer.WriteAll(); |
@@ -9,11 +9,11 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | ||||
<PackageReference Include="TensorFlow.NET" Version="0.100.5" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" /> | <ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" /> | ||||
<ProjectReference Include="..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
</Project> | </Project> |
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using Protobuf.Text; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Reflection.Metadata.Ecma335; | using System.Reflection.Metadata.Ecma335; | ||||
@@ -51,5 +52,201 @@ namespace Tensorflow.CodeGen | |||||
return result.ToString(); | return result.ToString(); | ||||
} | } | ||||
public static OpList ReadAllOpDefs(string path) | |||||
{ | |||||
var text = File.ReadAllText(path); | |||||
var opDefs = OpList.Parser.ParseText(text); | |||||
return opDefs; | |||||
} | |||||
// name, type string, default value | |||||
public static List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||||
{ | |||||
dynamicDefaultValues = new(); | |||||
List<(string, string, string)> res = new(); | |||||
foreach (var attr in op.Attr) | |||||
{ | |||||
if (attr.Type == "type") | |||||
{ | |||||
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||||
if (!found) | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
{ | |||||
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||||
string enumPath = typeof(TF_DataType).Name + "." + name; | |||||
res.Add((attr.Name, "TF_DataType", enumPath)); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||||
} | |||||
} | |||||
} | |||||
else if (attr.Type == "int") | |||||
{ | |||||
if (op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||||
{ | |||||
continue; | |||||
} | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||||
{ | |||||
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "int", "0")); | |||||
} | |||||
} | |||||
else if (attr.Type == "float") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||||
{ | |||||
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "float", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "string") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
{ | |||||
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "string", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "bool") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||||
{ | |||||
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "bool", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "shape") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||||
{ | |||||
if (attr.DefaultValue.Shape.UnknownRank) | |||||
{ | |||||
res.Add((attr.Name, "Shape", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
Shape shape = new Shape(attr.DefaultValue.Shape); | |||||
string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "Shape", $"null")); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "Shape", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(type)") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
{ | |||||
List<TF_DataType> values = new(); | |||||
foreach (var value in attr.DefaultValue.List.Type) | |||||
{ | |||||
values.Add(value.as_tf_dtype()); | |||||
} | |||||
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "TF_DataType[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(shape)") | |||||
{ | |||||
res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||||
} | |||||
else if (attr.Type == "list(string)") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
{ | |||||
List<string> values = new(); | |||||
foreach (var value in attr.DefaultValue.List.S) | |||||
{ | |||||
values.Add(value.ToStringUtf8()); | |||||
} | |||||
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "string[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "string[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(int)") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
{ | |||||
List<int> values = new(); | |||||
foreach (var value in attr.DefaultValue.List.I) | |||||
{ | |||||
values.Add((int)value); | |||||
} | |||||
string expression = "new int[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "int[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "int[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "list(float)") | |||||
{ | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
{ | |||||
List<float> values = new(); | |||||
foreach (var value in attr.DefaultValue.List.F) | |||||
{ | |||||
values.Add(value); | |||||
} | |||||
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "float[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "float[]", "NOVALUE")); | |||||
} | |||||
} | |||||
else if (attr.Type == "func") | |||||
{ | |||||
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||||
} | |||||
else if (attr.Type == "list(func)") | |||||
{ | |||||
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||||
} | |||||
else if (attr.Type == "tensor") | |||||
{ | |||||
res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
return res; | |||||
} | |||||
} | } | ||||
} | } |