|
|
@@ -0,0 +1,550 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Diagnostics; |
|
|
|
using System.Linq; |
|
|
|
using System.Reflection.Metadata.Ecma335; |
|
|
|
using System.Text; |
|
|
|
using System.Threading.Tasks; |
|
|
|
using Microsoft.CodeAnalysis.CSharp; |
|
|
|
|
|
|
|
namespace Tensorflow.CodeGen |
|
|
|
{ |
|
|
|
public class FunctionGenerator |
|
|
|
{ |
|
|
|
public void AppendFunction(OpDef op, StringBuilder sb) |
|
|
|
{ |
|
|
|
// TODO: add descriptions |
|
|
|
sb.Append("public static "); |
|
|
|
int outputArgsCount = op.OutputArg.Count; |
|
|
|
if (outputArgsCount > 1) |
|
|
|
{ |
|
|
|
sb.Append("Tensor[] "); |
|
|
|
} |
|
|
|
else if (outputArgsCount == 1) |
|
|
|
{ |
|
|
|
sb.Append("Tensor "); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.Append("Operation "); |
|
|
|
} |
|
|
|
string funcName = Utils.ConvertToUnderscore(op.Name); |
|
|
|
var token = SyntaxFactory.ParseToken(funcName); |
|
|
|
if (token.IsKeyword()) |
|
|
|
{ |
|
|
|
funcName = $"_{funcName}"; |
|
|
|
} |
|
|
|
sb.Append($" {funcName}("); |
|
|
|
|
|
|
|
// define args |
|
|
|
AppendArgs(op, sb); |
|
|
|
sb.Append(")\n{\n"); |
|
|
|
|
|
|
|
// begin to write main body |
|
|
|
sb.AppendLine("var _ctx = tf.Context;"); |
|
|
|
sb.AppendLine("if(_ctx.executing_eagerly()){"); |
|
|
|
|
|
|
|
if(HasRefArgs(op)) |
|
|
|
{ |
|
|
|
var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null); |
|
|
|
sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {possibleRefArg.Name} is a ref.\");"); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.Append("try\n{\n"); |
|
|
|
|
|
|
|
AppendFastPathExecute(op, sb); |
|
|
|
if (outputArgsCount == 0) |
|
|
|
{ |
|
|
|
sb.AppendLine("return null;"); |
|
|
|
} |
|
|
|
else if (outputArgsCount == 1) |
|
|
|
{ |
|
|
|
sb.AppendLine("return _fast_path_result[0];"); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.AppendLine("return _fast_path_result;"); |
|
|
|
} |
|
|
|
|
|
|
|
sb.AppendLine("}"); // try |
|
|
|
|
|
|
|
sb.Append("catch(Exception)\n{\n"); |
|
|
|
sb.AppendLine("}"); // catch |
|
|
|
|
|
|
|
sb.Append("try\n{\n"); |
|
|
|
AppendEagerFallbackCall(op, sb); |
|
|
|
sb.AppendLine("}"); // try |
|
|
|
|
|
|
|
sb.Append("catch(Exception)\n{\n"); |
|
|
|
sb.AppendLine("}"); // catch |
|
|
|
} |
|
|
|
|
|
|
|
sb.AppendLine("}"); // if |
|
|
|
|
|
|
|
// begin to use op helper. |
|
|
|
AppendOpHelperCall(op, sb); |
|
|
|
sb.AppendLine("var _result = _op.outputs;"); |
|
|
|
|
|
|
|
// check if it needs to record gradient. |
|
|
|
sb.Append("if(_execute.must_record_gradient())\n{\n"); |
|
|
|
sb.Append("object[] _attrs = new object[]{"); |
|
|
|
foreach (var attr in op.Attr) |
|
|
|
{ |
|
|
|
string attrRealName = attr.Name; |
|
|
|
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) |
|
|
|
{ |
|
|
|
attrRealName += "_"; |
|
|
|
} |
|
|
|
if (attr.Type == "type") |
|
|
|
{ |
|
|
|
sb.Append($"\"{attr.Name}\", _op._get_attr_type(\"{attrRealName}\"), "); |
|
|
|
} |
|
|
|
else if (attr.Type == "int") |
|
|
|
{ |
|
|
|
sb.Append($"\"{attr.Name}\", _op._get_attr_int(\"{attrRealName}\"), "); |
|
|
|
} |
|
|
|
else if (attr.Type == "bool") |
|
|
|
{ |
|
|
|
sb.Append($"\"{attr.Name}\", _op._get_attr_bool(\"{attrRealName}\"), "); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.Append($"\"{attr.Name}\", _op.get_attr(\"{attr.Name}\"), "); |
|
|
|
} |
|
|
|
} |
|
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
|
{ |
|
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
|
} |
|
|
|
sb.Append("};\n"); |
|
|
|
sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);"); |
|
|
|
|
|
|
|
sb.AppendLine("}"); // if |
|
|
|
|
|
|
|
if (outputArgsCount == 0) |
|
|
|
{ |
|
|
|
sb.AppendLine("return _op;"); |
|
|
|
} |
|
|
|
else if (outputArgsCount == 1) |
|
|
|
{ |
|
|
|
sb.AppendLine("return _result[0];"); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.AppendLine("return _result;"); |
|
|
|
} |
|
|
|
sb.AppendLine("}"); // body |
|
|
|
|
|
|
|
sb.AppendLine(); |
|
|
|
|
|
|
|
AppendEagerFallbackDefinition(op, sb); |
|
|
|
} |
|
|
|
|
|
|
|
public void AppendArgs(OpDef op, StringBuilder sb) |
|
|
|
{ |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string argName = arg.Name; |
|
|
|
var token = SyntaxFactory.ParseToken(argName); |
|
|
|
if (token.IsKeyword()) |
|
|
|
{ |
|
|
|
argName = $"{argName}_"; |
|
|
|
} |
|
|
|
if (!string.IsNullOrEmpty(arg.NumberAttr)) |
|
|
|
{ |
|
|
|
sb.Append($"Tensors {argName}, "); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.Append($"Tensor {argName}, "); |
|
|
|
} |
|
|
|
} |
|
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
foreach (var (key, (typeStr, value)) in attrValueDic) |
|
|
|
{ |
|
|
|
var token = SyntaxFactory.ParseToken(key); |
|
|
|
string realKey = key; |
|
|
|
if (token.IsKeyword()) |
|
|
|
{ |
|
|
|
realKey += "_"; |
|
|
|
} |
|
|
|
if (value != "NOVALUE") |
|
|
|
{ |
|
|
|
sb.Append($"{typeStr} {realKey} = {value}, "); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.Append($"{typeStr} {realKey}, "); |
|
|
|
} |
|
|
|
} |
|
|
|
sb.Append($"string? name = null"); |
|
|
|
} |
|
|
|
|
|
|
|
public void AppendFastPathExecute(OpDef op, StringBuilder sb) |
|
|
|
{ |
|
|
|
sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name, "); |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string attrArgName = arg.Name; |
|
|
|
if (SyntaxFactory.ParseToken(attrArgName).IsKeyword()) |
|
|
|
{ |
|
|
|
attrArgName += "_"; |
|
|
|
} |
|
|
|
sb.Append($"{attrArgName}, "); |
|
|
|
} |
|
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
foreach (var (key, _) in attrValueDic) |
|
|
|
{ |
|
|
|
sb.Append($"\"{key}\", {key}, "); |
|
|
|
} |
|
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
|
{ |
|
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
|
} |
|
|
|
sb.Append("));\n"); |
|
|
|
} |
|
|
|
|
|
|
|
public void AppendEagerFallbackCall(OpDef op, StringBuilder sb) |
|
|
|
{ |
|
|
|
string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback"; |
|
|
|
sb.Append($"return {funcName}("); |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string inputArgRealName = arg.Name; |
|
|
|
if (SyntaxFactory.ParseToken(inputArgRealName).IsKeyword()) |
|
|
|
{ |
|
|
|
inputArgRealName += "_"; |
|
|
|
} |
|
|
|
sb.Append($"{inputArgRealName}, "); |
|
|
|
} |
|
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
foreach (var (key, _) in attrValueDic) |
|
|
|
{ |
|
|
|
string keyRealName = key; |
|
|
|
if (SyntaxFactory.ParseToken(keyRealName).IsKeyword()) |
|
|
|
{ |
|
|
|
keyRealName += "_"; |
|
|
|
} |
|
|
|
sb.Append($"{key}: {keyRealName}, "); |
|
|
|
} |
|
|
|
sb.Append("name: name, ctx: _ctx);\n"); |
|
|
|
} |
|
|
|
|
|
|
|
public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb) |
|
|
|
{ |
|
|
|
sb.Append("public static Tensor"); |
|
|
|
int outputArgsCount = op.OutputArg.Count; |
|
|
|
if (outputArgsCount > 1) |
|
|
|
{ |
|
|
|
sb.Append("[]"); |
|
|
|
} |
|
|
|
string opName = op.Name; |
|
|
|
string funcName = Utils.ConvertToUnderscore(op.Name); |
|
|
|
sb.Append($" {funcName}_eager_fallback("); |
|
|
|
AppendFallBackFunctionArgs(op, sb); |
|
|
|
sb.Append(")\n{\n"); |
|
|
|
|
|
|
|
var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null); |
|
|
|
if (possibleRefArg is not null) |
|
|
|
{ |
|
|
|
sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." + |
|
|
|
$" Arg '{possibleRefArg.Name}' is a ref.\");"); |
|
|
|
sb.AppendLine("}"); // body |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
sb.Append("Tensor[] _inputs_flat = new Tensor[]{"); |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string realArgName = arg.Name; |
|
|
|
if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) |
|
|
|
{ |
|
|
|
realArgName = $"{realArgName}_"; |
|
|
|
} |
|
|
|
sb.Append($"{realArgName}, "); |
|
|
|
} |
|
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
|
{ |
|
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
|
} |
|
|
|
sb.Append("};\n"); |
|
|
|
|
|
|
|
sb.Append("object[] _attrs = new object[]{"); |
|
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
foreach (var attr in op.Attr) |
|
|
|
{ |
|
|
|
if (attr.Type == "type") |
|
|
|
{ |
|
|
|
bool found = false; |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string realArgName = arg.Name; |
|
|
|
if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) |
|
|
|
{ |
|
|
|
realArgName = $"{realArgName}_"; |
|
|
|
} |
|
|
|
if (arg.TypeAttr == attr.Name) |
|
|
|
{ |
|
|
|
sb.Append($"\"{attr.Name}\", {realArgName}.dtype, "); |
|
|
|
found = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!found) |
|
|
|
{ |
|
|
|
if (attr.Name.StartsWith("T") && attr.Name.Length > 1) |
|
|
|
{ |
|
|
|
string paramName = attr.Name.Substring(1); |
|
|
|
if (SyntaxFactory.ParseToken(paramName).IsKeyword()) |
|
|
|
{ |
|
|
|
paramName = $"{paramName}_"; |
|
|
|
} |
|
|
|
sb.Append($"\"{attr.Name}\", {paramName}.dtype, "); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
string attrRealName = attr.Name; |
|
|
|
if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) |
|
|
|
{ |
|
|
|
attrRealName = $"{attrRealName}_"; |
|
|
|
} |
|
|
|
sb.Append($"\"{attr.Name}\", {attrRealName}, "); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
else if(attr.Type == "int" && (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name))) |
|
|
|
{ |
|
|
|
bool found = false; |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string realArgName = arg.Name; |
|
|
|
if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) |
|
|
|
{ |
|
|
|
realArgName = $"{realArgName}_"; |
|
|
|
} |
|
|
|
if (arg.NumberAttr == attr.Name) |
|
|
|
{ |
|
|
|
sb.Append($"\"{attr.Name}\", {realArgName}.Length, "); |
|
|
|
found = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.Append($"\"{attr.Name}\", {attr.Name}, "); |
|
|
|
} |
|
|
|
} |
|
|
|
if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') |
|
|
|
{ |
|
|
|
sb.Remove(sb.Length - 2, 2); |
|
|
|
} |
|
|
|
sb.Append("};\n"); |
|
|
|
|
|
|
|
sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, " + |
|
|
|
$"attrs: _attrs, ctx: ctx, name: name);"); |
|
|
|
|
|
|
|
sb.Append("if(_execute.must_record_gradient())\n{\n"); |
|
|
|
|
|
|
|
sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);"); |
|
|
|
|
|
|
|
sb.AppendLine("}"); // if |
|
|
|
|
|
|
|
if (outputArgsCount == 0) |
|
|
|
{ |
|
|
|
sb.AppendLine("return null;"); |
|
|
|
} |
|
|
|
else if (outputArgsCount == 1) |
|
|
|
{ |
|
|
|
sb.AppendLine("return _result[0];"); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.AppendLine("return _result;"); |
|
|
|
} |
|
|
|
|
|
|
|
sb.AppendLine("}"); // body |
|
|
|
} |
|
|
|
|
|
|
|
public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb) |
|
|
|
{ |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string argName = arg.Name; |
|
|
|
var token = SyntaxFactory.ParseToken(argName); |
|
|
|
if (token.IsKeyword()) |
|
|
|
{ |
|
|
|
argName = $"{argName}_"; |
|
|
|
} |
|
|
|
if (!string.IsNullOrEmpty(arg.NumberAttr)) |
|
|
|
{ |
|
|
|
sb.Append($"Tensors {argName}, "); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
sb.Append($"Tensor {argName}, "); |
|
|
|
} |
|
|
|
} |
|
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
foreach (var (key, (typeStr, _)) in attrValueDic) |
|
|
|
{ |
|
|
|
var token = SyntaxFactory.ParseToken(key); |
|
|
|
string realKey = key; |
|
|
|
if (token.IsKeyword()) |
|
|
|
{ |
|
|
|
realKey += "_"; |
|
|
|
} |
|
|
|
sb.Append($"{typeStr} {realKey}, "); |
|
|
|
} |
|
|
|
sb.Append($"string name, Context ctx"); |
|
|
|
} |
|
|
|
|
|
|
|
public void AppendOpHelperCall(OpDef op, StringBuilder sb) |
|
|
|
{ |
|
|
|
sb.AppendLine("Dictionary<string, object> keywords = new();"); |
|
|
|
foreach (var arg in op.InputArg) |
|
|
|
{ |
|
|
|
string realArgName = arg.Name; |
|
|
|
if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) |
|
|
|
{ |
|
|
|
realArgName += "_"; |
|
|
|
} |
|
|
|
sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); |
|
|
|
} |
|
|
|
var attrValueDic = GetAttrsDefaultValue(op); |
|
|
|
foreach (var (key, _) in attrValueDic) |
|
|
|
{ |
|
|
|
sb.Append($"keywords[\"{key}\"] = {key};"); |
|
|
|
} |
|
|
|
sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); |
|
|
|
} |
|
|
|
|
|
|
|
// key, (type string, default value) |
|
|
|
public Dictionary<string, (string, string)> GetAttrsDefaultValue(OpDef op) |
|
|
|
{ |
|
|
|
Dictionary<string, (string, string)> dic = new(); |
|
|
|
foreach (var attr in op.Attr) |
|
|
|
{ |
|
|
|
if (attr.Type == "type") |
|
|
|
{ |
|
|
|
bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); |
|
|
|
if (!found) |
|
|
|
{ |
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) |
|
|
|
{ |
|
|
|
string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); |
|
|
|
string enumPath = typeof(TF_DataType).Name + "." + name; |
|
|
|
dic[attr.Name] = ("TF_DataType", enumPath); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("TF_DataType", "NOVALUE"); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
else if (attr.Type == "int") |
|
|
|
{ |
|
|
|
if(op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name)) |
|
|
|
{ |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("int", attr.DefaultValue.I.ToString()); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("int", "0"); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (attr.Type == "float") |
|
|
|
{ |
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("float", attr.DefaultValue.F.ToString() + "f"); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("float", "NOVALUE"); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (attr.Type == "string") |
|
|
|
{ |
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\""); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("string", "NOVALUE"); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (attr.Type == "bool") |
|
|
|
{ |
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("bool", attr.DefaultValue.B.ToString().ToLower()); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("bool", "NOVALUE"); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (attr.Type == "shape") |
|
|
|
{ |
|
|
|
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("Shape", $"null"); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("Shape", "NOVALUE"); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (attr.Type == "list(type)") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("TF_DataType[]", "NOVALUE"); |
|
|
|
} |
|
|
|
else if (attr.Type == "list(shape)") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("Shape[]", "NOVALUE"); |
|
|
|
} |
|
|
|
else if (attr.Type == "list(string)") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("string[]", "NOVALUE"); |
|
|
|
} |
|
|
|
else if (attr.Type == "list(int)") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("int[]", "NOVALUE"); |
|
|
|
} |
|
|
|
else if (attr.Type == "list(float)") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("float[]", "NOVALUE"); |
|
|
|
} |
|
|
|
else if (attr.Type == "func") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("Func<Tensors, Tensors>", "NOVALUE"); |
|
|
|
} |
|
|
|
else if (attr.Type == "list(func)") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("Func<Tensors, Tensors>[]", "NOVALUE"); |
|
|
|
} |
|
|
|
else if (attr.Type == "tensor") |
|
|
|
{ |
|
|
|
dic[attr.Name] = ("TensorProto", "NOVALUE"); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
throw new NotImplementedException(); |
|
|
|
} |
|
|
|
} |
|
|
|
return dic; |
|
|
|
} |
|
|
|
|
|
|
|
private static bool HasRefArgs(OpDef op) |
|
|
|
{ |
|
|
|
return op.InputArg.Any(x => x.IsRef); |
|
|
|
} |
|
|
|
} |
|
|
|
} |