Browse Source

feat: description generator of op code.

pull/1063/head
Yaohui Liu 2 years ago
parent
commit
28642568a2
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
6 changed files with 482 additions and 212 deletions
  1. +263
    -0
      Tensorflow.CodeGen/DescriptionGenerator.cs
  2. +6
    -195
      Tensorflow.CodeGen/FunctionGenerator.cs
  3. +13
    -13
      Tensorflow.CodeGen/GenOpsWriter.cs
  4. +1
    -2
      Tensorflow.CodeGen/Program.cs
  5. +1
    -1
      Tensorflow.CodeGen/Tensorflow.CodeGen.csproj
  6. +198
    -1
      Tensorflow.CodeGen/Utils.cs

+ 263
- 0
Tensorflow.CodeGen/DescriptionGenerator.cs View File

@@ -0,0 +1,263 @@
using Microsoft.CodeAnalysis.CSharp;
using Protobuf.Text;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace Tensorflow.CodeGen
{
public class DescriptionGenerator
{
private static readonly string replaceStrInner = "~~%~~";
private static readonly string replaceStrInnerQuotationMarks = "^%^";
Dictionary<string, Dictionary<string, string>> _opDescriptions = new Dictionary<string, Dictionary<string, string>>();
Dictionary<string, OpDef> _opDescriptionDefs = new Dictionary<string, OpDef>();
public DescriptionGenerator(string apiDefDirectory)
{
DirectoryInfo directory = new DirectoryInfo(apiDefDirectory);

int errors = 0;
foreach (FileInfo file in directory.GetFiles())
{
string target = file.Name.Split('.')[0].Split('_').Last();
OpDef op = null;
try
{
op = ReadOpDefs(file.FullName).Op[0];
}
catch
{
errors++;
continue;
}
_opDescriptionDefs[target] = op;
_opDescriptions[target] = new Dictionary<string, string>();
foreach (var arg in op.InputArg)
{
string argName = arg.Name;
var token = SyntaxFactory.ParseToken(argName);
if (token.IsKeyword())
{
argName = $"{argName}_";
}
_opDescriptions[target][argName] = arg.Description ?? "";
}
foreach (var arg in op.Attr)
{
var token = SyntaxFactory.ParseToken(arg.Name);
string realKey = arg.Name;
if (token.IsKeyword())
{
realKey += "_";
}
_opDescriptions[target][realKey] = arg.Description ?? "";
}
_opDescriptions[target]["SUMMARY"] = op.Summary ?? "";
_opDescriptions[target]["DESC"] = op.Description ?? "";
}
Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " +
$"the failed files number is large, or ignore it.");
}

/// <summary>
///
/// </summary>
/// <param name="op"></param>
/// <param name="sb"></param>
public void AppendDescription(OpDef fullOp, StringBuilder sb)
{
var opName = fullOp.Name;
if(_opDescriptions.TryGetValue(opName, out var op))
{
var def = _opDescriptionDefs[opName];
sb.AppendLine("/// <summary>");
sb.AppendLine($"/// {op["SUMMARY"]}");
sb.AppendLine("/// </summary>");

string totalDesc = op["DESC"];
if (!string.IsNullOrEmpty(totalDesc))
{
totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\"");
sb.AppendLine("/// <remarks>");
string[] lines = totalDesc.Split(replaceStrInner);
foreach (var line in lines)
{
sb.AppendLine($"/// {line}");
}
sb.AppendLine("/// </remarks>");
}

var argNames = GetInputArgNames(fullOp);
foreach (var argName in argNames)
{
if(op.TryGetValue(argName, out var desc))
{
desc = desc.Replace(replaceStrInnerQuotationMarks, "\"");
string[] lines = desc.Split(replaceStrInner);
sb.AppendLine($"/// <param name=\"{argName}\">");
foreach (var line in lines)
{
sb.AppendLine($"/// {line}");
}
sb.AppendLine("/// </param>");
}
else
{
sb.AppendLine($"/// <param name=\"{argName}\"></param>");
}
}

List<string> returnValueDescs = new();
foreach (var arg in def.OutputArg)
{
if (!string.IsNullOrEmpty(arg.Description))
{
returnValueDescs.Add($"{arg.Name}: {arg.Description}");
}
}
string returnValueDesc = "";
if (returnValueDescs.Count > 0)
{
returnValueDesc = string.Join(" && ", returnValueDescs);
}
sb.AppendLine($"/// <returns>{returnValueDesc}</returns>");
}
else
{
sb.AppendLine("/// <summary>");
sb.AppendLine($"///");
sb.AppendLine("/// </summary>");

var argNames = GetInputArgNames(fullOp);
foreach (var argName in argNames)
{
sb.AppendLine($"/// <param name=\"{argName}\"></param>");
}

sb.AppendLine($"/// <returns></returns>");
}
}

/// <summary>
///
/// </summary>
/// <param name="op">
/// </param>
/// <returns></returns>
/// <remarks></remarks>
public List<string> GetInputArgNames(OpDef op)
{
List<string> names = new();
foreach (var arg in op.InputArg)
{
string argName = arg.Name;
var token = SyntaxFactory.ParseToken(argName);
if (token.IsKeyword())
{
argName = $"{argName}_";
}
names.Add(argName);
}
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues);
foreach (var (key, typeStr, value) in attrValueDic)
{
var token = SyntaxFactory.ParseToken(key);
string realKey = key;
if (token.IsKeyword())
{
realKey += "_";
}
names.Add(realKey);
}
return names;
}

private static OpList ReadOpDefs(string path)
{
var text = File.ReadAllText(path);
text = RemoveLintTags(text);
text = PreProcessText(text);

string pattern = @"<<END([\s\S]*?)END";

// 定义用于替换的字符串
string replaceStrPrefix = "\"";
string replaceStrSuffix = "\"";

// 将匹配到的文本段全部替换
string replacedText = Regex.Replace(text, pattern, match => {
string matchedText = match.Value;
string innerText = match.Groups[1].Value;
innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks)
.Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符
return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾
}, RegexOptions.Multiline);

var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse<OpList>(replacedText);
return opDefs;
}

static string PreProcessText(string input)
{
int depth = 0;
int endBlockDepth = -1;
StringBuilder sb = new StringBuilder();
for (int i = 0; i < input.Length; i++)
{
char c = input[i];
if (c == '{')
{
depth++;
sb.Append(c);
}
else if (c == '}')
{
if (depth == endBlockDepth)
{
sb.Append("END\n");
endBlockDepth = -1;
}
sb.Append(c);
depth--;
}
else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "<<END")
{
endBlockDepth = depth;
sb.Append("<<END");
i += 4;
}
else if (c == 'E' && i + 3 < input.Length && input.Substring(i, 3) == "END")
{
endBlockDepth = -1;
sb.Append("END");
i += 2;
}
else
{
sb.Append(c);
}
}

string output = sb.ToString();
return output;
}

static string RemoveLintTags(string input)
{
string[] lines = input.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None);
StringBuilder sb = new StringBuilder();
foreach (string line in lines)
{
if (!line.TrimStart().StartsWith("# LINT"))
{
sb.AppendLine(line);
}
}
return sb.ToString().TrimEnd();
}
}
}

+ 6
- 195
Tensorflow.CodeGen/FunctionGenerator.cs View File

@@ -44,7 +44,7 @@ namespace Tensorflow.CodeGen
// begin to write main body
sb.AppendLine("var _ctx = tf.Context;");

var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues);
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues);
// deal with dynamic default values.
foreach(var (name, expr) in dynamicDefaultValues)
{
@@ -183,7 +183,7 @@ namespace Tensorflow.CodeGen
sb.Append($"Tensor {argName}, ");
}
}
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues);
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues);
foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE"))
{
var token = SyntaxFactory.ParseToken(key);
@@ -226,7 +226,7 @@ namespace Tensorflow.CodeGen
}

sb.Append("}, attrs = new Dictionary<string, object>(){ ");
var attrValueDic = GetAttrsDefaultValue(op, out var _);
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
foreach (var (key, _, _) in attrValueDic)
{
sb.Append($"[\"{key}\"] = {key}, ");
@@ -252,7 +252,7 @@ namespace Tensorflow.CodeGen
}
sb.Append($"{inputArgRealName}, ");
}
var attrValueDic = GetAttrsDefaultValue(op, out var _);
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
foreach (var (key, _, _) in attrValueDic)
{
string keyRealName = key;
@@ -439,7 +439,7 @@ namespace Tensorflow.CodeGen
sb.Append($"Tensor {argName}, ");
}
}
var attrValueDic = GetAttrsDefaultValue(op, out var _);
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
foreach (var (key, typeStr, _) in attrValueDic)
{
var token = SyntaxFactory.ParseToken(key);
@@ -465,7 +465,7 @@ namespace Tensorflow.CodeGen
}
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};");
}
var attrValueDic = GetAttrsDefaultValue(op, out var _);
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _);
foreach (var (key, _, _) in attrValueDic)
{
sb.AppendLine($"keywords[\"{key}\"] = {key};");
@@ -473,195 +473,6 @@ namespace Tensorflow.CodeGen
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
}

// name, type string, default value
public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues)
{
dynamicDefaultValues = new();
List<(string, string, string)> res = new();
foreach (var attr in op.Attr)
{
if (attr.Type == "type")
{
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name);
if (!found)
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype());
string enumPath = typeof(TF_DataType).Name + "." + name;
res.Add((attr.Name, "TF_DataType", enumPath));
}
else
{
res.Add((attr.Name, "TF_DataType", "NOVALUE"));
}
}
}
else if (attr.Type == "int")
{
if(op.InputArg.Any(x => x.NumberAttr == attr.Name))
{
continue;
}
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I)
{
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString()));
}
else
{
res.Add((attr.Name, "int", "0"));
}
}
else if (attr.Type == "float")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F)
{
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f"));
}
else
{
res.Add((attr.Name, "float", "NOVALUE"));
}
}
else if (attr.Type == "string")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
{
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\""));
}
else
{
res.Add((attr.Name, "string", "NOVALUE"));
}
}
else if (attr.Type == "bool")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B)
{
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower()));
}
else
{
res.Add((attr.Name, "bool", "NOVALUE"));
}
}
else if (attr.Type == "shape")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape)
{
if (attr.DefaultValue.Shape.UnknownRank)
{
res.Add((attr.Name, "Shape", $"null"));
}
else
{
Shape shape = new Shape(attr.DefaultValue.Shape);
string expression = $"new Shape({string.Join(", ", shape.dims)})";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "Shape", $"null"));
}
}
else
{
res.Add((attr.Name, "Shape", "NOVALUE"));
}
}
else if (attr.Type == "list(type)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
List<TF_DataType> values = new();
foreach (var value in attr.DefaultValue.List.Type)
{
values.Add(value.as_tf_dtype());
}
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "TF_DataType[]", $"null"));
}
else
{
res.Add((attr.Name, "TF_DataType[]", "NOVALUE"));
}
}
else if (attr.Type == "list(shape)")
{
res.Add((attr.Name, "Shape[]", "NOVALUE"));
}
else if (attr.Type == "list(string)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
{
List<string> values = new();
foreach (var value in attr.DefaultValue.List.S)
{
values.Add(value.ToStringUtf8());
}
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "string[]", $"null"));
}
else
{
res.Add((attr.Name, "string[]", "NOVALUE"));
}
}
else if (attr.Type == "list(int)")
{
if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<int> values = new();
foreach(var value in attr.DefaultValue.List.I)
{
values.Add((int)value);
}
string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "int[]", $"null"));
}
else
{
res.Add((attr.Name, "int[]", "NOVALUE"));
}
}
else if (attr.Type == "list(float)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<float> values = new();
foreach (var value in attr.DefaultValue.List.F)
{
values.Add(value);
}
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "float[]", $"null"));
}
else
{
res.Add((attr.Name, "float[]", "NOVALUE"));
}
}
else if (attr.Type == "func")
{
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE"));
}
else if (attr.Type == "list(func)")
{
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE"));
}
else if (attr.Type == "tensor")
{
res.Add((attr.Name, "TensorProto", "NOVALUE"));
}
else
{
throw new NotImplementedException();
}
}
return res;
}

private static bool HasRefArgs(OpDef op)
{
return op.InputArg.Any(x => x.IsRef);


+ 13
- 13
Tensorflow.CodeGen/GenOpsWriter.cs View File

@@ -12,16 +12,18 @@ namespace Tensorflow.CodeGen
private string _basePath;
private Dictionary<string, OpDef> _opMap;
private OpClassifier _opClassifier;
private FunctionGenerator _g = new();
private FunctionGenerator _fg = new();
private DescriptionGenerator _dg;

public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename)
public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename)
{
_basePath = basePath;

var opDefs = ReadAllOpDefs(opDefFilename);
var opDefs = Utils.ReadAllOpDefs(opDefFilename);
_opMap = opDefs.Op.ToDictionary(
x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x);
x => Utils.ConvertToUnderscore(x.Name), x => x);
_opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name)));
_dg = new DescriptionGenerator(apiDefFilesDirectory);
}

public void WriteAll()
@@ -53,12 +55,17 @@ namespace Tensorflow.CodeGen
if(_opMap.ContainsKey(funcName))
{
var opDef = _opMap[funcName];
_g.AppendFunction(opDef, sb);

// write the descriptions.
_dg.AppendDescription(opDef, sb);

// write the function body.
_fg.AppendFunction(opDef, sb);
}
else if (funcName.StartsWith("_"))
{
var opDef = _opMap[funcName.Substring(1)];
_g.AppendFunction(opDef, sb);
_fg.AppendFunction(opDef, sb);
}
}

@@ -69,12 +76,5 @@ namespace Tensorflow.CodeGen
File.WriteAllText(fullFilePath, sb.ToString());
}
}

private OpList ReadAllOpDefs(string path)
{
var text = File.ReadAllText(path);
var opDefs = OpList.Parser.ParseText(text);
return opDefs;
}
}
}

+ 1
- 2
Tensorflow.CodeGen/Program.cs View File

@@ -5,10 +5,9 @@ using System.Text;
using System.Xml.Linq;
using Tensorflow.CodeGen;

//Console.WriteLine(Utils.ConvertToUnderscore("LRN"));

GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops",
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt");

writer.WriteAll();

+ 1
- 1
Tensorflow.CodeGen/Tensorflow.CodeGen.csproj View File

@@ -9,11 +9,11 @@

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" />
<PackageReference Include="TensorFlow.NET" Version="0.100.5" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" />
<ProjectReference Include="..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>

</Project>

+ 198
- 1
Tensorflow.CodeGen/Utils.cs View File

@@ -1,4 +1,5 @@
using System;
using Protobuf.Text;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
@@ -51,5 +52,201 @@ namespace Tensorflow.CodeGen

return result.ToString();
}

public static OpList ReadAllOpDefs(string path)
{
var text = File.ReadAllText(path);
var opDefs = OpList.Parser.ParseText(text);
return opDefs;
}

// name, type string, default value
public static List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues)
{
dynamicDefaultValues = new();
List<(string, string, string)> res = new();
foreach (var attr in op.Attr)
{
if (attr.Type == "type")
{
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name);
if (!found)
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype());
string enumPath = typeof(TF_DataType).Name + "." + name;
res.Add((attr.Name, "TF_DataType", enumPath));
}
else
{
res.Add((attr.Name, "TF_DataType", "NOVALUE"));
}
}
}
else if (attr.Type == "int")
{
if (op.InputArg.Any(x => x.NumberAttr == attr.Name))
{
continue;
}
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I)
{
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString()));
}
else
{
res.Add((attr.Name, "int", "0"));
}
}
else if (attr.Type == "float")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F)
{
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f"));
}
else
{
res.Add((attr.Name, "float", "NOVALUE"));
}
}
else if (attr.Type == "string")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
{
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\""));
}
else
{
res.Add((attr.Name, "string", "NOVALUE"));
}
}
else if (attr.Type == "bool")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B)
{
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower()));
}
else
{
res.Add((attr.Name, "bool", "NOVALUE"));
}
}
else if (attr.Type == "shape")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape)
{
if (attr.DefaultValue.Shape.UnknownRank)
{
res.Add((attr.Name, "Shape", $"null"));
}
else
{
Shape shape = new Shape(attr.DefaultValue.Shape);
string expression = $"new Shape({string.Join(", ", shape.dims)})";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "Shape", $"null"));
}
}
else
{
res.Add((attr.Name, "Shape", "NOVALUE"));
}
}
else if (attr.Type == "list(type)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
List<TF_DataType> values = new();
foreach (var value in attr.DefaultValue.List.Type)
{
values.Add(value.as_tf_dtype());
}
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "TF_DataType[]", $"null"));
}
else
{
res.Add((attr.Name, "TF_DataType[]", "NOVALUE"));
}
}
else if (attr.Type == "list(shape)")
{
res.Add((attr.Name, "Shape[]", "NOVALUE"));
}
else if (attr.Type == "list(string)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
{
List<string> values = new();
foreach (var value in attr.DefaultValue.List.S)
{
values.Add(value.ToStringUtf8());
}
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "string[]", $"null"));
}
else
{
res.Add((attr.Name, "string[]", "NOVALUE"));
}
}
else if (attr.Type == "list(int)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<int> values = new();
foreach (var value in attr.DefaultValue.List.I)
{
values.Add((int)value);
}
string expression = "new int[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "int[]", $"null"));
}
else
{
res.Add((attr.Name, "int[]", "NOVALUE"));
}
}
else if (attr.Type == "list(float)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<float> values = new();
foreach (var value in attr.DefaultValue.List.F)
{
values.Add(value);
}
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "float[]", $"null"));
}
else
{
res.Add((attr.Name, "float[]", "NOVALUE"));
}
}
else if (attr.Type == "func")
{
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE"));
}
else if (attr.Type == "list(func)")
{
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE"));
}
else if (attr.Type == "tensor")
{
res.Add((attr.Name, "TensorProto", "NOVALUE"));
}
else
{
throw new NotImplementedException();
}
}
return res;
}
}
}

Loading…
Cancel
Save