|
@@ -2,6 +2,7 @@ |
|
|
using System.Collections.Generic; |
|
|
using System.Collections.Generic; |
|
|
using System.Diagnostics; |
|
|
using System.Diagnostics; |
|
|
using System.Linq; |
|
|
using System.Linq; |
|
|
|
|
|
using System.Linq.Expressions; |
|
|
using System.Reflection.Metadata.Ecma335; |
|
|
using System.Reflection.Metadata.Ecma335; |
|
|
using System.Text; |
|
|
using System.Text; |
|
|
using System.Threading.Tasks; |
|
|
using System.Threading.Tasks; |
|
@@ -16,17 +17,17 @@ namespace Tensorflow.CodeGen |
|
|
// TODO: add descriptions |
|
|
// TODO: add descriptions |
|
|
sb.Append("public static "); |
|
|
sb.Append("public static "); |
|
|
int outputArgsCount = op.OutputArg.Count; |
|
|
int outputArgsCount = op.OutputArg.Count; |
|
|
if (outputArgsCount > 1) |
|
|
|
|
|
|
|
|
if (outputArgsCount == 0) |
|
|
{ |
|
|
{ |
|
|
sb.Append("Tensor[] "); |
|
|
|
|
|
|
|
|
sb.Append("Operation "); |
|
|
} |
|
|
} |
|
|
else if (outputArgsCount == 1) |
|
|
|
|
|
|
|
|
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) |
|
|
{ |
|
|
{ |
|
|
sb.Append("Tensor "); |
|
|
sb.Append("Tensor "); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
sb.Append("Operation "); |
|
|
|
|
|
|
|
|
sb.Append("Tensor[] "); |
|
|
} |
|
|
} |
|
|
string funcName = Utils.ConvertToUnderscore(op.Name); |
|
|
string funcName = Utils.ConvertToUnderscore(op.Name); |
|
|
var token = SyntaxFactory.ParseToken(funcName); |
|
|
var token = SyntaxFactory.ParseToken(funcName); |
|
@@ -42,6 +43,17 @@ namespace Tensorflow.CodeGen |
|
|
|
|
|
|
|
|
// begin to write main body |
|
|
// begin to write main body |
|
|
sb.AppendLine("var _ctx = tf.Context;"); |
|
|
sb.AppendLine("var _ctx = tf.Context;"); |
|
|
|
|
|
|
|
|
|
|
|
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); |
|
|
|
|
|
// deal with dynamic default values. |
|
|
|
|
|
foreach(var (name, expr) in dynamicDefaultValues) |
|
|
|
|
|
{ |
|
|
|
|
|
sb.AppendLine($"if({name} is null)"); |
|
|
|
|
|
sb.AppendLine("{"); |
|
|
|
|
|
sb.AppendLine($"{name} = {expr};"); |
|
|
|
|
|
sb.AppendLine("}"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
sb.AppendLine("if(_ctx.executing_eagerly()){"); |
|
|
sb.AppendLine("if(_ctx.executing_eagerly()){"); |
|
|
|
|
|
|
|
|
if(HasRefArgs(op)) |
|
|
if(HasRefArgs(op)) |
|
@@ -58,7 +70,7 @@ namespace Tensorflow.CodeGen |
|
|
{ |
|
|
{ |
|
|
sb.AppendLine("return null;"); |
|
|
sb.AppendLine("return null;"); |
|
|
} |
|
|
} |
|
|
else if (outputArgsCount == 1) |
|
|
|
|
|
|
|
|
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) |
|
|
{ |
|
|
{ |
|
|
sb.AppendLine("return _fast_path_result[0];"); |
|
|
sb.AppendLine("return _fast_path_result[0];"); |
|
|
} |
|
|
} |
|
@@ -82,6 +94,17 @@ namespace Tensorflow.CodeGen |
|
|
|
|
|
|
|
|
sb.AppendLine("}"); // if |
|
|
sb.AppendLine("}"); // if |
|
|
|
|
|
|
|
|
|
|
|
foreach(var (name, type, value) in attrValueDic.Where(x => x.Item2 == "string")) |
|
|
|
|
|
{ |
|
|
|
|
|
if(value != "NOVALUE") |
|
|
|
|
|
{ |
|
|
|
|
|
sb.AppendLine($"if({name} is null)"); |
|
|
|
|
|
sb.AppendLine("{"); |
|
|
|
|
|
sb.AppendLine($"{name} = {value};"); |
|
|
|
|
|
sb.AppendLine("}"); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// begin to use op helper. |
|
|
// begin to use op helper. |
|
|
AppendOpHelperCall(op, sb); |
|
|
AppendOpHelperCall(op, sb); |
|
|
sb.AppendLine("var _result = _op.outputs;"); |
|
|
sb.AppendLine("var _result = _op.outputs;"); |
|
@@ -126,7 +149,7 @@ namespace Tensorflow.CodeGen |
|
|
{ |
|
|
{ |
|
|
sb.AppendLine("return _op;"); |
|
|
sb.AppendLine("return _op;"); |
|
|
} |
|
|
} |
|
|
else if (outputArgsCount == 1) |
|
|
|
|
|
|
|
|
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) |
|
|
{ |
|
|
{ |
|
|
sb.AppendLine("return _result[0];"); |
|
|
sb.AppendLine("return _result[0];"); |
|
|
} |
|
|
} |
|
@@ -160,8 +183,8 @@ namespace Tensorflow.CodeGen |
|
|
sb.Append($"Tensor {argName}, "); |
|
|
sb.Append($"Tensor {argName}, "); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
|
|
foreach (var (key, (typeStr, value)) in attrValueDic) |
|
|
|
|
|
|
|
|
var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); |
|
|
|
|
|
foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) |
|
|
{ |
|
|
{ |
|
|
var token = SyntaxFactory.ParseToken(key); |
|
|
var token = SyntaxFactory.ParseToken(key); |
|
|
string realKey = key; |
|
|
string realKey = key; |
|
@@ -169,21 +192,25 @@ namespace Tensorflow.CodeGen |
|
|
{ |
|
|
{ |
|
|
realKey += "_"; |
|
|
realKey += "_"; |
|
|
} |
|
|
} |
|
|
if (value != "NOVALUE") |
|
|
|
|
|
{ |
|
|
|
|
|
sb.Append($"{typeStr} {realKey} = {value}, "); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
|
|
|
sb.Append($"{typeStr} {realKey}, "); |
|
|
|
|
|
} |
|
|
|
|
|
foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 != "NOVALUE")) |
|
|
|
|
|
{ |
|
|
|
|
|
var token = SyntaxFactory.ParseToken(key); |
|
|
|
|
|
string realKey = key; |
|
|
|
|
|
if (token.IsKeyword()) |
|
|
{ |
|
|
{ |
|
|
sb.Append($"{typeStr} {realKey}, "); |
|
|
|
|
|
|
|
|
realKey += "_"; |
|
|
} |
|
|
} |
|
|
|
|
|
sb.Append($"{typeStr} {realKey} = {value}, "); |
|
|
} |
|
|
} |
|
|
sb.Append($"string? name = null"); |
|
|
sb.Append($"string? name = null"); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public void AppendFastPathExecute(OpDef op, StringBuilder sb) |
|
|
public void AppendFastPathExecute(OpDef op, StringBuilder sb) |
|
|
{ |
|
|
{ |
|
|
sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name, "); |
|
|
|
|
|
|
|
|
sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)"); |
|
|
|
|
|
sb.Append("{ args = new object[]{ "); |
|
|
foreach (var arg in op.InputArg) |
|
|
foreach (var arg in op.InputArg) |
|
|
{ |
|
|
{ |
|
|
string attrArgName = arg.Name; |
|
|
string attrArgName = arg.Name; |
|
@@ -193,16 +220,23 @@ namespace Tensorflow.CodeGen |
|
|
} |
|
|
} |
|
|
sb.Append($"{attrArgName}, "); |
|
|
sb.Append($"{attrArgName}, "); |
|
|
} |
|
|
} |
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
|
|
foreach (var (key, _) in attrValueDic) |
|
|
|
|
|
|
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
{ |
|
|
{ |
|
|
sb.Append($"\"{key}\", {key}, "); |
|
|
|
|
|
|
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
sb.Append("}, attrs = new Dictionary<string, object>(){ "); |
|
|
|
|
|
var attrValueDic = GetAttrsDefaultValue(op, out var _); |
|
|
|
|
|
foreach (var (key, _, _) in attrValueDic) |
|
|
|
|
|
{ |
|
|
|
|
|
sb.Append($"[\"{key}\"] = {key}, "); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
{ |
|
|
{ |
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
} |
|
|
} |
|
|
sb.Append("));\n"); |
|
|
|
|
|
|
|
|
sb.Append("}});\n"); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public void AppendEagerFallbackCall(OpDef op, StringBuilder sb) |
|
|
public void AppendEagerFallbackCall(OpDef op, StringBuilder sb) |
|
@@ -218,8 +252,8 @@ namespace Tensorflow.CodeGen |
|
|
} |
|
|
} |
|
|
sb.Append($"{inputArgRealName}, "); |
|
|
sb.Append($"{inputArgRealName}, "); |
|
|
} |
|
|
} |
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
|
|
foreach (var (key, _) in attrValueDic) |
|
|
|
|
|
|
|
|
var attrValueDic = GetAttrsDefaultValue(op, out var _); |
|
|
|
|
|
foreach (var (key, _, _) in attrValueDic) |
|
|
{ |
|
|
{ |
|
|
string keyRealName = key; |
|
|
string keyRealName = key; |
|
|
if (SyntaxFactory.ParseToken(keyRealName).IsKeyword()) |
|
|
if (SyntaxFactory.ParseToken(keyRealName).IsKeyword()) |
|
@@ -233,11 +267,19 @@ namespace Tensorflow.CodeGen |
|
|
|
|
|
|
|
|
public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb) |
|
|
public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb) |
|
|
{ |
|
|
{ |
|
|
sb.Append("public static Tensor"); |
|
|
|
|
|
|
|
|
sb.Append("public static "); |
|
|
int outputArgsCount = op.OutputArg.Count; |
|
|
int outputArgsCount = op.OutputArg.Count; |
|
|
if (outputArgsCount > 1) |
|
|
|
|
|
|
|
|
if (outputArgsCount == 0) |
|
|
|
|
|
{ |
|
|
|
|
|
sb.Append("Operation "); |
|
|
|
|
|
} |
|
|
|
|
|
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
sb.Append("Tensor "); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
{ |
|
|
{ |
|
|
sb.Append("[]"); |
|
|
|
|
|
|
|
|
sb.Append("Tensor[] "); |
|
|
} |
|
|
} |
|
|
string opName = op.Name; |
|
|
string opName = op.Name; |
|
|
string funcName = Utils.ConvertToUnderscore(op.Name); |
|
|
string funcName = Utils.ConvertToUnderscore(op.Name); |
|
@@ -254,24 +296,47 @@ namespace Tensorflow.CodeGen |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
sb.Append("Tensor[] _inputs_flat = new Tensor[]{"); |
|
|
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
|
|
|
|
|
if(op.InputArg.Any(x => !string.IsNullOrEmpty(x.NumberAttr))) |
|
|
{ |
|
|
{ |
|
|
string realArgName = arg.Name; |
|
|
|
|
|
if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) |
|
|
|
|
|
|
|
|
sb.AppendLine("List<Tensor> _inputs_flat_list = new();"); |
|
|
|
|
|
foreach (var arg in op.InputArg) |
|
|
{ |
|
|
{ |
|
|
realArgName = $"{realArgName}_"; |
|
|
|
|
|
|
|
|
string realArgName = arg.Name; |
|
|
|
|
|
if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) |
|
|
|
|
|
{ |
|
|
|
|
|
realArgName = $"{realArgName}_"; |
|
|
|
|
|
} |
|
|
|
|
|
if (string.IsNullOrEmpty(arg.NumberAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
sb.AppendLine($"_inputs_flat_list.Add({realArgName});"); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
sb.AppendLine($"_inputs_flat_list.AddRange({realArgName});"); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
sb.Append($"{realArgName}, "); |
|
|
|
|
|
|
|
|
sb.AppendLine($"var _inputs_flat = _inputs_flat_list.ToArray();"); |
|
|
} |
|
|
} |
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
|
|
|
|
|
|
else |
|
|
{ |
|
|
{ |
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
|
|
|
|
|
|
sb.Append("Tensor[] _inputs_flat = new Tensor[]{"); |
|
|
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
|
|
{ |
|
|
|
|
|
string realArgName = arg.Name; |
|
|
|
|
|
if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) |
|
|
|
|
|
{ |
|
|
|
|
|
realArgName = $"{realArgName}_"; |
|
|
|
|
|
} |
|
|
|
|
|
sb.Append($"{realArgName}, "); |
|
|
|
|
|
} |
|
|
|
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
|
|
|
{ |
|
|
|
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
|
|
|
} |
|
|
|
|
|
sb.Append("};\n"); |
|
|
} |
|
|
} |
|
|
sb.Append("};\n"); |
|
|
|
|
|
|
|
|
|
|
|
sb.Append("object[] _attrs = new object[]{"); |
|
|
sb.Append("object[] _attrs = new object[]{"); |
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
|
|
foreach (var attr in op.Attr) |
|
|
foreach (var attr in op.Attr) |
|
|
{ |
|
|
{ |
|
|
if (attr.Type == "type") |
|
|
if (attr.Type == "type") |
|
@@ -293,27 +358,15 @@ namespace Tensorflow.CodeGen |
|
|
} |
|
|
} |
|
|
if (!found) |
|
|
if (!found) |
|
|
{ |
|
|
{ |
|
|
if (attr.Name.StartsWith("T") && attr.Name.Length > 1) |
|
|
|
|
|
{ |
|
|
|
|
|
string paramName = attr.Name.Substring(1); |
|
|
|
|
|
if (SyntaxFactory.ParseToken(paramName).IsKeyword()) |
|
|
|
|
|
{ |
|
|
|
|
|
paramName = $"{paramName}_"; |
|
|
|
|
|
} |
|
|
|
|
|
sb.Append($"\"{attr.Name}\", {paramName}.dtype, "); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
|
|
|
string attrRealName = attr.Name; |
|
|
|
|
|
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) |
|
|
{ |
|
|
{ |
|
|
string attrRealName = attr.Name; |
|
|
|
|
|
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) |
|
|
|
|
|
{ |
|
|
|
|
|
attrRealName = $"{attrRealName}_"; |
|
|
|
|
|
} |
|
|
|
|
|
sb.Append($"\"{attr.Name}\", {attrRealName}, "); |
|
|
|
|
|
|
|
|
attrRealName = $"{attrRealName}_"; |
|
|
} |
|
|
} |
|
|
|
|
|
sb.Append($"\"{attr.Name}\", {attrRealName}, "); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
else if(attr.Type == "int" && (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name))) |
|
|
|
|
|
|
|
|
else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name)) |
|
|
{ |
|
|
{ |
|
|
bool found = false; |
|
|
bool found = false; |
|
|
foreach (var arg in op.InputArg) |
|
|
foreach (var arg in op.InputArg) |
|
@@ -355,7 +408,7 @@ namespace Tensorflow.CodeGen |
|
|
{ |
|
|
{ |
|
|
sb.AppendLine("return null;"); |
|
|
sb.AppendLine("return null;"); |
|
|
} |
|
|
} |
|
|
else if (outputArgsCount == 1) |
|
|
|
|
|
|
|
|
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) |
|
|
{ |
|
|
{ |
|
|
sb.AppendLine("return _result[0];"); |
|
|
sb.AppendLine("return _result[0];"); |
|
|
} |
|
|
} |
|
@@ -386,8 +439,8 @@ namespace Tensorflow.CodeGen |
|
|
sb.Append($"Tensor {argName}, "); |
|
|
sb.Append($"Tensor {argName}, "); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
|
|
foreach (var (key, (typeStr, _)) in attrValueDic) |
|
|
|
|
|
|
|
|
var attrValueDic = GetAttrsDefaultValue(op, out var _); |
|
|
|
|
|
foreach (var (key, typeStr, _) in attrValueDic) |
|
|
{ |
|
|
{ |
|
|
var token = SyntaxFactory.ParseToken(key); |
|
|
var token = SyntaxFactory.ParseToken(key); |
|
|
string realKey = key; |
|
|
string realKey = key; |
|
@@ -412,18 +465,19 @@ namespace Tensorflow.CodeGen |
|
|
} |
|
|
} |
|
|
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); |
|
|
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); |
|
|
} |
|
|
} |
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
|
|
foreach (var (key, _) in attrValueDic) |
|
|
|
|
|
|
|
|
var attrValueDic = GetAttrsDefaultValue(op, out var _); |
|
|
|
|
|
foreach (var (key, _, _) in attrValueDic) |
|
|
{ |
|
|
{ |
|
|
sb.Append($"keywords[\"{key}\"] = {key};"); |
|
|
|
|
|
|
|
|
sb.AppendLine($"keywords[\"{key}\"] = {key};"); |
|
|
} |
|
|
} |
|
|
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); |
|
|
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// key, (type string, default value) |
|
|
|
|
|
public Dictionary<string, (string, string)> GetAttrsDefaultValue(OpDef op) |
|
|
|
|
|
|
|
|
// name, type string, default value |
|
|
|
|
|
public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) |
|
|
{ |
|
|
{ |
|
|
Dictionary<string, (string, string)> dic = new(); |
|
|
|
|
|
|
|
|
dynamicDefaultValues = new(); |
|
|
|
|
|
List<(string, string, string)> res = new(); |
|
|
foreach (var attr in op.Attr) |
|
|
foreach (var attr in op.Attr) |
|
|
{ |
|
|
{ |
|
|
if (attr.Type == "type") |
|
|
if (attr.Type == "type") |
|
@@ -435,111 +489,177 @@ namespace Tensorflow.CodeGen |
|
|
{ |
|
|
{ |
|
|
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); |
|
|
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); |
|
|
string enumPath = typeof(TF_DataType).Name + "." + name; |
|
|
string enumPath = typeof(TF_DataType).Name + "." + name; |
|
|
dic[attr.Name] = ("TF_DataType", enumPath); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "TF_DataType", enumPath)); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("TF_DataType", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "TF_DataType", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "int") |
|
|
else if (attr.Type == "int") |
|
|
{ |
|
|
{ |
|
|
if(op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name)) |
|
|
|
|
|
|
|
|
if(op.InputArg.Any(x => x.NumberAttr == attr.Name)) |
|
|
{ |
|
|
{ |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("int", attr.DefaultValue.I.ToString()); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("int", "0"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "int", "0")); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "float") |
|
|
else if (attr.Type == "float") |
|
|
{ |
|
|
{ |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("float", attr.DefaultValue.F.ToString() + "f"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("float", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "float", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "string") |
|
|
else if (attr.Type == "string") |
|
|
{ |
|
|
{ |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\""); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("string", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "string", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "bool") |
|
|
else if (attr.Type == "bool") |
|
|
{ |
|
|
{ |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("bool", attr.DefaultValue.B.ToString().ToLower()); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("bool", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "bool", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "shape") |
|
|
else if (attr.Type == "shape") |
|
|
{ |
|
|
{ |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) |
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("Shape", $"null"); |
|
|
|
|
|
|
|
|
if (attr.DefaultValue.Shape.UnknownRank) |
|
|
|
|
|
{ |
|
|
|
|
|
res.Add((attr.Name, "Shape", $"null")); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
Shape shape = new Shape(attr.DefaultValue.Shape); |
|
|
|
|
|
string expression = $"new Shape({string.Join(", ", shape.dims)})"; |
|
|
|
|
|
dynamicDefaultValues[attr.Name] = expression; |
|
|
|
|
|
res.Add((attr.Name, "Shape", $"null")); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("Shape", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "Shape", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "list(type)") |
|
|
else if (attr.Type == "list(type)") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("TF_DataType[]", "NOVALUE"); |
|
|
|
|
|
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) |
|
|
|
|
|
{ |
|
|
|
|
|
List<TF_DataType> values = new(); |
|
|
|
|
|
foreach (var value in attr.DefaultValue.List.Type) |
|
|
|
|
|
{ |
|
|
|
|
|
values.Add(value.as_tf_dtype()); |
|
|
|
|
|
} |
|
|
|
|
|
string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; |
|
|
|
|
|
dynamicDefaultValues[attr.Name] = expression; |
|
|
|
|
|
res.Add((attr.Name, "TF_DataType[]", $"null")); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "list(shape)") |
|
|
else if (attr.Type == "list(shape)") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("Shape[]", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "Shape[]", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "list(string)") |
|
|
else if (attr.Type == "list(string)") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("string[]", "NOVALUE"); |
|
|
|
|
|
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) |
|
|
|
|
|
{ |
|
|
|
|
|
List<string> values = new(); |
|
|
|
|
|
foreach (var value in attr.DefaultValue.List.S) |
|
|
|
|
|
{ |
|
|
|
|
|
values.Add(value.ToStringUtf8()); |
|
|
|
|
|
} |
|
|
|
|
|
string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; |
|
|
|
|
|
dynamicDefaultValues[attr.Name] = expression; |
|
|
|
|
|
res.Add((attr.Name, "string[]", $"null")); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
res.Add((attr.Name, "string[]", "NOVALUE")); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "list(int)") |
|
|
else if (attr.Type == "list(int)") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("int[]", "NOVALUE"); |
|
|
|
|
|
|
|
|
if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) |
|
|
|
|
|
{ |
|
|
|
|
|
List<int> values = new(); |
|
|
|
|
|
foreach(var value in attr.DefaultValue.List.I) |
|
|
|
|
|
{ |
|
|
|
|
|
values.Add((int)value); |
|
|
|
|
|
} |
|
|
|
|
|
string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}"; |
|
|
|
|
|
dynamicDefaultValues[attr.Name] = expression; |
|
|
|
|
|
res.Add((attr.Name, "int[]", $"null")); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
res.Add((attr.Name, "int[]", "NOVALUE")); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "list(float)") |
|
|
else if (attr.Type == "list(float)") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("float[]", "NOVALUE"); |
|
|
|
|
|
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) |
|
|
|
|
|
{ |
|
|
|
|
|
List<float> values = new(); |
|
|
|
|
|
foreach (var value in attr.DefaultValue.List.F) |
|
|
|
|
|
{ |
|
|
|
|
|
values.Add(value); |
|
|
|
|
|
} |
|
|
|
|
|
string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; |
|
|
|
|
|
dynamicDefaultValues[attr.Name] = expression; |
|
|
|
|
|
res.Add((attr.Name, "float[]", $"null")); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
res.Add((attr.Name, "float[]", "NOVALUE")); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "func") |
|
|
else if (attr.Type == "func") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("Func<Tensors, Tensors>", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "list(func)") |
|
|
else if (attr.Type == "list(func)") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("Func<Tensors, Tensors>[]", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
else if (attr.Type == "tensor") |
|
|
else if (attr.Type == "tensor") |
|
|
{ |
|
|
{ |
|
|
dic[attr.Name] = ("TensorProto", "NOVALUE"); |
|
|
|
|
|
|
|
|
res.Add((attr.Name, "TensorProto", "NOVALUE")); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
throw new NotImplementedException(); |
|
|
throw new NotImplementedException(); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return dic; |
|
|
|
|
|
|
|
|
return res; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
private static bool HasRefArgs(OpDef op) |
|
|
private static bool HasRefArgs(OpDef op) |
|
|