@@ -0,0 +1,263 @@ | |||
using Microsoft.CodeAnalysis.CSharp; | |||
using Protobuf.Text; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Reflection.Metadata.Ecma335; | |||
using System.Text; | |||
using System.Text.RegularExpressions; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.CodeGen | |||
{ | |||
public class DescriptionGenerator | |||
{ | |||
private static readonly string replaceStrInner = "~~%~~"; | |||
private static readonly string replaceStrInnerQuotationMarks = "^%^"; | |||
Dictionary<string, Dictionary<string, string>> _opDescriptions = new Dictionary<string, Dictionary<string, string>>(); | |||
Dictionary<string, OpDef> _opDescriptionDefs = new Dictionary<string, OpDef>(); | |||
public DescriptionGenerator(string apiDefDirectory) | |||
{ | |||
DirectoryInfo directory = new DirectoryInfo(apiDefDirectory); | |||
int errors = 0; | |||
foreach (FileInfo file in directory.GetFiles()) | |||
{ | |||
string target = file.Name.Split('.')[0].Split('_').Last(); | |||
OpDef op = null; | |||
try | |||
{ | |||
op = ReadOpDefs(file.FullName).Op[0]; | |||
} | |||
catch | |||
{ | |||
errors++; | |||
continue; | |||
} | |||
_opDescriptionDefs[target] = op; | |||
_opDescriptions[target] = new Dictionary<string, string>(); | |||
foreach (var arg in op.InputArg) | |||
{ | |||
string argName = arg.Name; | |||
var token = SyntaxFactory.ParseToken(argName); | |||
if (token.IsKeyword()) | |||
{ | |||
argName = $"{argName}_"; | |||
} | |||
_opDescriptions[target][argName] = arg.Description ?? ""; | |||
} | |||
foreach (var arg in op.Attr) | |||
{ | |||
var token = SyntaxFactory.ParseToken(arg.Name); | |||
string realKey = arg.Name; | |||
if (token.IsKeyword()) | |||
{ | |||
realKey += "_"; | |||
} | |||
_opDescriptions[target][realKey] = arg.Description ?? ""; | |||
} | |||
_opDescriptions[target]["SUMMARY"] = op.Summary ?? ""; | |||
_opDescriptions[target]["DESC"] = op.Description ?? ""; | |||
} | |||
Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " + | |||
$"the failed files number is large, or ignore it."); | |||
} | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="op"></param> | |||
/// <param name="sb"></param> | |||
public void AppendDescription(OpDef fullOp, StringBuilder sb) | |||
{ | |||
var opName = fullOp.Name; | |||
if(_opDescriptions.TryGetValue(opName, out var op)) | |||
{ | |||
var def = _opDescriptionDefs[opName]; | |||
sb.AppendLine("/// <summary>"); | |||
sb.AppendLine($"/// {op["SUMMARY"]}"); | |||
sb.AppendLine("/// </summary>"); | |||
string totalDesc = op["DESC"]; | |||
if (!string.IsNullOrEmpty(totalDesc)) | |||
{ | |||
totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\""); | |||
sb.AppendLine("/// <remarks>"); | |||
string[] lines = totalDesc.Split(replaceStrInner); | |||
foreach (var line in lines) | |||
{ | |||
sb.AppendLine($"/// {line}"); | |||
} | |||
sb.AppendLine("/// </remarks>"); | |||
} | |||
var argNames = GetInputArgNames(fullOp); | |||
foreach (var argName in argNames) | |||
{ | |||
if(op.TryGetValue(argName, out var desc)) | |||
{ | |||
desc = desc.Replace(replaceStrInnerQuotationMarks, "\""); | |||
string[] lines = desc.Split(replaceStrInner); | |||
sb.AppendLine($"/// <param name=\"{argName}\">"); | |||
foreach (var line in lines) | |||
{ | |||
sb.AppendLine($"/// {line}"); | |||
} | |||
sb.AppendLine("/// </param>"); | |||
} | |||
else | |||
{ | |||
sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||
} | |||
} | |||
List<string> returnValueDescs = new(); | |||
foreach (var arg in def.OutputArg) | |||
{ | |||
if (!string.IsNullOrEmpty(arg.Description)) | |||
{ | |||
returnValueDescs.Add($"{arg.Name}: {arg.Description}"); | |||
} | |||
} | |||
string returnValueDesc = ""; | |||
if (returnValueDescs.Count > 0) | |||
{ | |||
returnValueDesc = string.Join(" && ", returnValueDescs); | |||
} | |||
sb.AppendLine($"/// <returns>{returnValueDesc}</returns>"); | |||
} | |||
else | |||
{ | |||
sb.AppendLine("/// <summary>"); | |||
sb.AppendLine($"///"); | |||
sb.AppendLine("/// </summary>"); | |||
var argNames = GetInputArgNames(fullOp); | |||
foreach (var argName in argNames) | |||
{ | |||
sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||
} | |||
sb.AppendLine($"/// <returns></returns>"); | |||
} | |||
} | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="op"> | |||
/// </param> | |||
/// <returns></returns> | |||
/// <remarks></remarks> | |||
public List<string> GetInputArgNames(OpDef op) | |||
{ | |||
List<string> names = new(); | |||
foreach (var arg in op.InputArg) | |||
{ | |||
string argName = arg.Name; | |||
var token = SyntaxFactory.ParseToken(argName); | |||
if (token.IsKeyword()) | |||
{ | |||
argName = $"{argName}_"; | |||
} | |||
names.Add(argName); | |||
} | |||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
foreach (var (key, typeStr, value) in attrValueDic) | |||
{ | |||
var token = SyntaxFactory.ParseToken(key); | |||
string realKey = key; | |||
if (token.IsKeyword()) | |||
{ | |||
realKey += "_"; | |||
} | |||
names.Add(realKey); | |||
} | |||
return names; | |||
} | |||
private static OpList ReadOpDefs(string path) | |||
{ | |||
var text = File.ReadAllText(path); | |||
text = RemoveLintTags(text); | |||
text = PreProcessText(text); | |||
string pattern = @"<<END([\s\S]*?)END"; | |||
// 定义用于替换的字符串 | |||
string replaceStrPrefix = "\""; | |||
string replaceStrSuffix = "\""; | |||
// 将匹配到的文本段全部替换 | |||
string replacedText = Regex.Replace(text, pattern, match => { | |||
string matchedText = match.Value; | |||
string innerText = match.Groups[1].Value; | |||
innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks) | |||
.Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符 | |||
return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾 | |||
}, RegexOptions.Multiline); | |||
var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse<OpList>(replacedText); | |||
return opDefs; | |||
} | |||
static string PreProcessText(string input) | |||
{ | |||
int depth = 0; | |||
int endBlockDepth = -1; | |||
StringBuilder sb = new StringBuilder(); | |||
for (int i = 0; i < input.Length; i++) | |||
{ | |||
char c = input[i]; | |||
if (c == '{') | |||
{ | |||
depth++; | |||
sb.Append(c); | |||
} | |||
else if (c == '}') | |||
{ | |||
if (depth == endBlockDepth) | |||
{ | |||
sb.Append("END\n"); | |||
endBlockDepth = -1; | |||
} | |||
sb.Append(c); | |||
depth--; | |||
} | |||
else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "<<END") | |||
{ | |||
endBlockDepth = depth; | |||
sb.Append("<<END"); | |||
i += 4; | |||
} | |||
else if (c == 'E' && i + 3 < input.Length && input.Substring(i, 3) == "END") | |||
{ | |||
endBlockDepth = -1; | |||
sb.Append("END"); | |||
i += 2; | |||
} | |||
else | |||
{ | |||
sb.Append(c); | |||
} | |||
} | |||
string output = sb.ToString(); | |||
return output; | |||
} | |||
static string RemoveLintTags(string input) | |||
{ | |||
string[] lines = input.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None); | |||
StringBuilder sb = new StringBuilder(); | |||
foreach (string line in lines) | |||
{ | |||
if (!line.TrimStart().StartsWith("# LINT")) | |||
{ | |||
sb.AppendLine(line); | |||
} | |||
} | |||
return sb.ToString().TrimEnd(); | |||
} | |||
} | |||
} |
@@ -44,7 +44,7 @@ namespace Tensorflow.CodeGen | |||
// begin to write main body | |||
sb.AppendLine("var _ctx = tf.Context;"); | |||
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
// deal with dynamic default values. | |||
foreach(var (name, expr) in dynamicDefaultValues) | |||
{ | |||
@@ -183,7 +183,7 @@ namespace Tensorflow.CodeGen | |||
sb.Append($"Tensor {argName}, "); | |||
} | |||
} | |||
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) | |||
{ | |||
var token = SyntaxFactory.ParseToken(key); | |||
@@ -226,7 +226,7 @@ namespace Tensorflow.CodeGen | |||
} | |||
sb.Append("}, attrs = new Dictionary<string, object>(){ "); | |||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
foreach (var (key, _, _) in attrValueDic) | |||
{ | |||
sb.Append($"[\"{key}\"] = {key}, "); | |||
@@ -252,7 +252,7 @@ namespace Tensorflow.CodeGen | |||
} | |||
sb.Append($"{inputArgRealName}, "); | |||
} | |||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
foreach (var (key, _, _) in attrValueDic) | |||
{ | |||
string keyRealName = key; | |||
@@ -439,7 +439,7 @@ namespace Tensorflow.CodeGen | |||
sb.Append($"Tensor {argName}, "); | |||
} | |||
} | |||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
foreach (var (key, typeStr, _) in attrValueDic) | |||
{ | |||
var token = SyntaxFactory.ParseToken(key); | |||
@@ -465,7 +465,7 @@ namespace Tensorflow.CodeGen | |||
} | |||
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); | |||
} | |||
var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
foreach (var (key, _, _) in attrValueDic) | |||
{ | |||
sb.AppendLine($"keywords[\"{key}\"] = {key};"); | |||
@@ -473,195 +473,6 @@ namespace Tensorflow.CodeGen | |||
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); | |||
} | |||
// name, type string, default value | |||
public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||
{ | |||
dynamicDefaultValues = new(); | |||
List<(string, string, string)> res = new(); | |||
foreach (var attr in op.Attr) | |||
{ | |||
if (attr.Type == "type") | |||
{ | |||
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||
if (!found) | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
{ | |||
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||
string enumPath = typeof(TF_DataType).Name + "." + name; | |||
res.Add((attr.Name, "TF_DataType", enumPath)); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||
} | |||
} | |||
} | |||
else if (attr.Type == "int") | |||
{ | |||
if(op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||
{ | |||
continue; | |||
} | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||
{ | |||
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "int", "0")); | |||
} | |||
} | |||
else if (attr.Type == "float") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||
{ | |||
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "float", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "string") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
{ | |||
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "string", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "bool") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||
{ | |||
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "bool", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "shape") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||
{ | |||
if (attr.DefaultValue.Shape.UnknownRank) | |||
{ | |||
res.Add((attr.Name, "Shape", $"null")); | |||
} | |||
else | |||
{ | |||
Shape shape = new Shape(attr.DefaultValue.Shape); | |||
string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "Shape", $"null")); | |||
} | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "Shape", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(type)") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
{ | |||
List<TF_DataType> values = new(); | |||
foreach (var value in attr.DefaultValue.List.Type) | |||
{ | |||
values.Add(value.as_tf_dtype()); | |||
} | |||
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "TF_DataType[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(shape)") | |||
{ | |||
res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||
} | |||
else if (attr.Type == "list(string)") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
{ | |||
List<string> values = new(); | |||
foreach (var value in attr.DefaultValue.List.S) | |||
{ | |||
values.Add(value.ToStringUtf8()); | |||
} | |||
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "string[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "string[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(int)") | |||
{ | |||
if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
List<int> values = new(); | |||
foreach(var value in attr.DefaultValue.List.I) | |||
{ | |||
values.Add((int)value); | |||
} | |||
string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "int[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "int[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(float)") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
List<float> values = new(); | |||
foreach (var value in attr.DefaultValue.List.F) | |||
{ | |||
values.Add(value); | |||
} | |||
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "float[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "float[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "func") | |||
{ | |||
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||
} | |||
else if (attr.Type == "list(func)") | |||
{ | |||
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||
} | |||
else if (attr.Type == "tensor") | |||
{ | |||
res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
return res; | |||
} | |||
private static bool HasRefArgs(OpDef op) | |||
{ | |||
return op.InputArg.Any(x => x.IsRef); | |||
@@ -12,16 +12,18 @@ namespace Tensorflow.CodeGen | |||
private string _basePath; | |||
private Dictionary<string, OpDef> _opMap; | |||
private OpClassifier _opClassifier; | |||
private FunctionGenerator _g = new(); | |||
private FunctionGenerator _fg = new(); | |||
private DescriptionGenerator _dg; | |||
public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename) | |||
public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename) | |||
{ | |||
_basePath = basePath; | |||
var opDefs = ReadAllOpDefs(opDefFilename); | |||
var opDefs = Utils.ReadAllOpDefs(opDefFilename); | |||
_opMap = opDefs.Op.ToDictionary( | |||
x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); | |||
x => Utils.ConvertToUnderscore(x.Name), x => x); | |||
_opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); | |||
_dg = new DescriptionGenerator(apiDefFilesDirectory); | |||
} | |||
public void WriteAll() | |||
@@ -53,12 +55,17 @@ namespace Tensorflow.CodeGen | |||
if(_opMap.ContainsKey(funcName)) | |||
{ | |||
var opDef = _opMap[funcName]; | |||
_g.AppendFunction(opDef, sb); | |||
// write the descriptions. | |||
_dg.AppendDescription(opDef, sb); | |||
// write the function body. | |||
_fg.AppendFunction(opDef, sb); | |||
} | |||
else if (funcName.StartsWith("_")) | |||
{ | |||
var opDef = _opMap[funcName.Substring(1)]; | |||
_g.AppendFunction(opDef, sb); | |||
_fg.AppendFunction(opDef, sb); | |||
} | |||
} | |||
@@ -69,12 +76,5 @@ namespace Tensorflow.CodeGen | |||
File.WriteAllText(fullFilePath, sb.ToString()); | |||
} | |||
} | |||
private OpList ReadAllOpDefs(string path) | |||
{ | |||
var text = File.ReadAllText(path); | |||
var opDefs = OpList.Parser.ParseText(text); | |||
return opDefs; | |||
} | |||
} | |||
} |
@@ -5,10 +5,9 @@ using System.Text; | |||
using System.Xml.Linq; | |||
using Tensorflow.CodeGen; | |||
//Console.WriteLine(Utils.ConvertToUnderscore("LRN")); | |||
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | |||
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | |||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", | |||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | |||
writer.WriteAll(); |
@@ -9,11 +9,11 @@ | |||
<ItemGroup> | |||
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.100.5" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" /> | |||
<ProjectReference Include="..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
</ItemGroup> | |||
</Project> |
@@ -1,4 +1,5 @@ | |||
using System; | |||
using Protobuf.Text; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Reflection.Metadata.Ecma335; | |||
@@ -51,5 +52,201 @@ namespace Tensorflow.CodeGen | |||
return result.ToString(); | |||
} | |||
public static OpList ReadAllOpDefs(string path) | |||
{ | |||
var text = File.ReadAllText(path); | |||
var opDefs = OpList.Parser.ParseText(text); | |||
return opDefs; | |||
} | |||
// name, type string, default value | |||
public static List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||
{ | |||
dynamicDefaultValues = new(); | |||
List<(string, string, string)> res = new(); | |||
foreach (var attr in op.Attr) | |||
{ | |||
if (attr.Type == "type") | |||
{ | |||
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||
if (!found) | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
{ | |||
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||
string enumPath = typeof(TF_DataType).Name + "." + name; | |||
res.Add((attr.Name, "TF_DataType", enumPath)); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||
} | |||
} | |||
} | |||
else if (attr.Type == "int") | |||
{ | |||
if (op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||
{ | |||
continue; | |||
} | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||
{ | |||
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "int", "0")); | |||
} | |||
} | |||
else if (attr.Type == "float") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||
{ | |||
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "float", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "string") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
{ | |||
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "string", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "bool") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||
{ | |||
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "bool", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "shape") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||
{ | |||
if (attr.DefaultValue.Shape.UnknownRank) | |||
{ | |||
res.Add((attr.Name, "Shape", $"null")); | |||
} | |||
else | |||
{ | |||
Shape shape = new Shape(attr.DefaultValue.Shape); | |||
string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "Shape", $"null")); | |||
} | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "Shape", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(type)") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
{ | |||
List<TF_DataType> values = new(); | |||
foreach (var value in attr.DefaultValue.List.Type) | |||
{ | |||
values.Add(value.as_tf_dtype()); | |||
} | |||
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "TF_DataType[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(shape)") | |||
{ | |||
res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||
} | |||
else if (attr.Type == "list(string)") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
{ | |||
List<string> values = new(); | |||
foreach (var value in attr.DefaultValue.List.S) | |||
{ | |||
values.Add(value.ToStringUtf8()); | |||
} | |||
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "string[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "string[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(int)") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
List<int> values = new(); | |||
foreach (var value in attr.DefaultValue.List.I) | |||
{ | |||
values.Add((int)value); | |||
} | |||
string expression = "new int[]{" + $"{string.Join(", ", values)}" + "}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "int[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "int[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "list(float)") | |||
{ | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
List<float> values = new(); | |||
foreach (var value in attr.DefaultValue.List.F) | |||
{ | |||
values.Add(value); | |||
} | |||
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||
dynamicDefaultValues[attr.Name] = expression; | |||
res.Add((attr.Name, "float[]", $"null")); | |||
} | |||
else | |||
{ | |||
res.Add((attr.Name, "float[]", "NOVALUE")); | |||
} | |||
} | |||
else if (attr.Type == "func") | |||
{ | |||
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||
} | |||
else if (attr.Type == "list(func)") | |||
{ | |||
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||
} | |||
else if (attr.Type == "tensor") | |||
{ | |||
res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
return res; | |||
} | |||
} | |||
} |