|
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.Linq;
- using System.Linq.Expressions;
- using System.Reflection.Metadata.Ecma335;
- using System.Text;
- using System.Threading.Tasks;
- using Microsoft.CodeAnalysis.CSharp;
-
- namespace Tensorflow.CodeGen
- {
- public class FunctionGenerator
- {
- /// <summary>
- /// Appends to <paramref name="sb"/> the C# source text of the public static wrapper for
- /// <paramref name="op"/>, followed by the definition of its <c>_eager_fallback</c> companion.
- /// </summary>
- /// <remarks>
- /// The emitted wrapper consists of: the signature (via <see cref="AppendArgs"/>), null-checks
- /// that materialise non-constant defaults, an eager section (fast path, then eager fallback;
- /// ops with ref inputs emit a <c>throw</c> instead), and finally a section built on
- /// <c>_apply_op_helper</c> with optional gradient recording.
- /// </remarks>
- /// <param name="op">Op definition to generate code for.</param>
- /// <param name="sb">Builder that receives the generated source.</param>
- public void AppendFunction(OpDef op, StringBuilder sb)
- {
-     // Parameters are declared by AppendArgs with a trailing '_' when their name is a C#
-     // keyword, so every reference to them in the emitted body must use the same spelling.
-     static string Escape(string identifier) =>
-         SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;
-
-     int outputArgsCount = op.OutputArg.Count;
-     bool returnsNothing = outputArgsCount == 0;
-     bool returnsSingleTensor = outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr);
-
-     // TODO: add descriptions
-     sb.Append("public static ");
-     sb.Append(returnsNothing ? "Operation " : returnsSingleTensor ? "Tensor " : "Tensor[] ");
-
-     string funcName = Utils.ConvertToUnderscore(op.Name);
-     if (SyntaxFactory.ParseToken(funcName).IsKeyword())
-     {
-         // Function names are prefixed (not suffixed) when they collide with a keyword.
-         funcName = $"_{funcName}";
-     }
-     sb.Append($" {funcName}(");
-
-     // define args
-     AppendArgs(op, sb);
-     sb.Append(")\n{\n");
-
-     // begin to write main body
-     sb.AppendLine("var _ctx = tf.Context;");
-
-     var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues);
-     // Defaults that are not compile-time constants are declared as null in the signature
-     // and assigned here.
-     foreach (var (name, expr) in dynamicDefaultValues)
-     {
-         string paramName = Escape(name);
-         sb.AppendLine($"if({paramName} is null)");
-         sb.AppendLine("{");
-         sb.AppendLine($"{paramName} = {expr};");
-         sb.AppendLine("}");
-     }
-
-     sb.AppendLine("if(_ctx.executing_eagerly()){");
-
-     if (HasRefArgs(op))
-     {
-         var refArg = op.InputArg.First(x => x.IsRef);
-         sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {refArg.Name} is a ref.\");");
-     }
-     else
-     {
-         sb.Append("try\n{\n");
-         AppendFastPathExecute(op, sb);
-         sb.AppendLine(returnsNothing
-             ? "return null;"
-             : returnsSingleTensor ? "return _fast_path_result[0];" : "return _fast_path_result;");
-         sb.AppendLine("}"); // try
-
-         // The emitted catch is intentionally empty: control falls through to the fallback.
-         sb.Append("catch(Exception)\n{\n");
-         sb.AppendLine("}"); // catch
-
-         sb.Append("try\n{\n");
-         AppendEagerFallbackCall(op, sb);
-         sb.AppendLine("}"); // try
-
-         sb.Append("catch(Exception)\n{\n");
-         sb.AppendLine("}"); // catch
-     }
-
-     sb.AppendLine("}"); // if
-
-     // String attrs with a default get a null-guard before the op-helper section.
-     foreach (var (name, _, value) in attrValueDic.Where(x => x.Item2 == "string"))
-     {
-         if (value != "NOVALUE")
-         {
-             string paramName = Escape(name);
-             sb.AppendLine($"if({paramName} is null)");
-             sb.AppendLine("{");
-             sb.AppendLine($"{paramName} = {value};");
-             sb.AppendLine("}");
-         }
-     }
-
-     // begin to use op helper.
-     AppendOpHelperCall(op, sb);
-     sb.AppendLine("var _result = _op.outputs;");
-
-     // check if it needs to record gradient.
-     sb.Append("if(_execute.must_record_gradient())\n{\n");
-     sb.Append("object[] _attrs = new object[]{");
-     // The string handed to the _op getters is the op's own attr name, never the
-     // keyword-escaped C# identifier.
-     var attrEntries = op.Attr.Select(attr => attr.Type switch
-     {
-         "type" => $"\"{attr.Name}\", _op._get_attr_type(\"{attr.Name}\")",
-         "int" => $"\"{attr.Name}\", _op._get_attr_int(\"{attr.Name}\")",
-         "bool" => $"\"{attr.Name}\", _op._get_attr_bool(\"{attr.Name}\")",
-         _ => $"\"{attr.Name}\", _op.get_attr(\"{attr.Name}\")",
-     });
-     sb.Append(string.Join(", ", attrEntries));
-     sb.Append("};\n");
-     sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);");
-
-     sb.AppendLine("}"); // if
-
-     sb.AppendLine(returnsNothing
-         ? "return _op;"
-         : returnsSingleTensor ? "return _result[0];" : "return _result;");
-     sb.AppendLine("}"); // body
-
-     sb.AppendLine();
-
-     AppendEagerFallbackDefinition(op, sb);
- }
-
- /// <summary>
- /// Appends the parameter list of the generated wrapper to <paramref name="sb"/>: one
- /// parameter per input arg, then attrs without a default, then attrs with a default, and
- /// finally <c>string? name = null</c>.
- /// </summary>
- /// <param name="op">Op definition whose inputs and attrs become parameters.</param>
- /// <param name="sb">Builder that receives the parameter list (without parentheses).</param>
- public void AppendArgs(OpDef op, StringBuilder sb)
- {
-     // Names that are C# keywords get a trailing underscore.
-     static string Safe(string id) => SyntaxFactory.ParseToken(id).IsKeyword() ? $"{id}_" : id;
-
-     foreach (var input in op.InputArg)
-     {
-         // Inputs that carry a NumberAttr are declared as Tensors, all others as Tensor.
-         string tensorType = string.IsNullOrEmpty(input.NumberAttr) ? "Tensor" : "Tensors";
-         sb.Append($"{tensorType} {Safe(input.Name)}, ");
-     }
-
-     var attrs = GetAttrsDefaultValue(op, out _);
-
-     // Parameters without a default must precede the optional ones.
-     foreach (var (attrName, attrType, _) in attrs.Where(a => a.Item3 == "NOVALUE"))
-     {
-         sb.Append($"{attrType} {Safe(attrName)}, ");
-     }
-
-     foreach (var (attrName, attrType, defaultExpr) in attrs.Where(a => a.Item3 != "NOVALUE"))
-     {
-         sb.Append($"{attrType} {Safe(attrName)} = {defaultExpr}, ");
-     }
-
-     sb.Append("string? name = null");
- }
-
- /// <summary>
- /// Appends the statement that assigns <c>_fast_path_result</c> from
- /// <c>tf.Runner.TFE_FastPathExecute</c>, passing every input arg in <c>args</c> and every
- /// attr returned by <see cref="GetAttrsDefaultValue"/> in <c>attrs</c>.
- /// </summary>
- /// <param name="op">Op definition being generated.</param>
- /// <param name="sb">Builder that receives the generated statement.</param>
- public void AppendFastPathExecute(OpDef op, StringBuilder sb)
- {
-     // Wrapper parameters are declared with a trailing '_' when the name is a C# keyword;
-     // the emitted expression must reference them by that same identifier.
-     static string Escape(string identifier) =>
-         SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;
-
-     sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)");
-     sb.Append("{ args = new object[]{ ");
-     sb.Append(string.Join(", ", op.InputArg.Select(arg => Escape(arg.Name))));
-
-     sb.Append("}, attrs = new Dictionary<string, object>(){ ");
-     var attrValueDic = GetAttrsDefaultValue(op, out _);
-     // The dictionary key stays the op's attr name; only the value (a C# variable
-     // reference) is keyword-escaped.
-     sb.Append(string.Join(", ", attrValueDic.Select(attr => $"[\"{attr.Item1}\"] = {Escape(attr.Item1)}")));
-     sb.Append("}});\n");
- }
-
- /// <summary>
- /// Appends a <c>return &lt;op&gt;_eager_fallback(...)</c> statement that forwards the
- /// wrapper's inputs positionally and its attrs, <c>name</c> and <c>_ctx</c> as named arguments.
- /// </summary>
- /// <param name="op">Op definition being generated.</param>
- /// <param name="sb">Builder that receives the generated statement.</param>
- public void AppendEagerFallbackCall(OpDef op, StringBuilder sb)
- {
-     // Both the wrapper and the fallback declare keyword-named parameters with a trailing '_'.
-     static string Escape(string identifier) =>
-         SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;
-
-     string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback";
-     sb.Append($"return {funcName}(");
-     foreach (var arg in op.InputArg)
-     {
-         sb.Append($"{Escape(arg.Name)}, ");
-     }
-
-     var attrValueDic = GetAttrsDefaultValue(op, out _);
-     foreach (var (key, _, _) in attrValueDic)
-     {
-         // A named-argument label has to match the callee's parameter name, which
-         // AppendFallBackFunctionArgs declares in escaped form.
-         string parameterName = Escape(key);
-         sb.Append($"{parameterName}: {parameterName}, ");
-     }
-     sb.Append("name: name, ctx: _ctx);\n");
- }
-
- /// <summary>
- /// Appends the full definition of <c>&lt;op&gt;_eager_fallback</c> to <paramref name="sb"/>.
- /// </summary>
- /// <remarks>
- /// For an op with a ref input the emitted body is only a <c>throw</c>. Otherwise the body
- /// flattens the inputs into <c>_inputs_flat</c>, builds the <c>_attrs</c> array, calls
- /// <c>_execute.execute</c>, optionally records the gradient and returns the result.
- /// </remarks>
- /// <param name="op">Op definition being generated.</param>
- /// <param name="sb">Builder that receives the generated source.</param>
- public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb)
- {
-     // Parameters are declared by AppendFallBackFunctionArgs with a trailing '_' when the
-     // name is a C# keyword; every variable reference emitted below must match.
-     static string Escape(string identifier) =>
-         SyntaxFactory.ParseToken(identifier).IsKeyword() ? $"{identifier}_" : identifier;
-
-     int outputArgsCount = op.OutputArg.Count;
-     bool returnsNothing = outputArgsCount == 0;
-     bool returnsSingleTensor = outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr);
-
-     sb.Append("public static ");
-     sb.Append(returnsNothing ? "Operation " : returnsSingleTensor ? "Tensor " : "Tensor[] ");
-
-     string funcName = Utils.ConvertToUnderscore(op.Name);
-     sb.Append($" {funcName}_eager_fallback(");
-     AppendFallBackFunctionArgs(op, sb);
-     sb.Append(")\n{\n");
-
-     var refArg = op.InputArg.FirstOrDefault(x => x.IsRef);
-     if (refArg is not null)
-     {
-         sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." +
-             $" Arg '{refArg.Name}' is a ref.\");");
-         sb.AppendLine("}"); // body
-         return;
-     }
-
-     if (op.InputArg.Any(x => !string.IsNullOrEmpty(x.NumberAttr)))
-     {
-         // At least one input is a tensor list: flatten through a List<Tensor>.
-         sb.AppendLine("List<Tensor> _inputs_flat_list = new();");
-         foreach (var arg in op.InputArg)
-         {
-             string adder = string.IsNullOrEmpty(arg.NumberAttr) ? "Add" : "AddRange";
-             sb.AppendLine($"_inputs_flat_list.{adder}({Escape(arg.Name)});");
-         }
-         sb.AppendLine("var _inputs_flat = _inputs_flat_list.ToArray();");
-     }
-     else
-     {
-         sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
-         sb.Append(string.Join(", ", op.InputArg.Select(arg => Escape(arg.Name))));
-         sb.Append("};\n");
-     }
-
-     // _attrs is a flat array of alternating attr name / value expression.
-     var attrEntries = new List<string>();
-     foreach (var attr in op.Attr)
-     {
-         string valueExpression;
-         if (attr.Type == "type")
-         {
-             // Use the dtype of the first input bound to this type attr; otherwise the
-             // attr is itself a parameter of the fallback.
-             var typedInput = op.InputArg.FirstOrDefault(x => x.TypeAttr == attr.Name);
-             valueExpression = typedInput is not null
-                 ? $"{Escape(typedInput.Name)}.dtype"
-                 : Escape(attr.Name);
-         }
-         else if (attr.Type == "int" && op.InputArg.FirstOrDefault(x => x.NumberAttr == attr.Name) is { } countedInput)
-         {
-             // The attr is the length of the first tensor-list input that names it.
-             valueExpression = $"{Escape(countedInput.Name)}.Length";
-         }
-         else
-         {
-             valueExpression = Escape(attr.Name);
-         }
-         attrEntries.Add($"\"{attr.Name}\", {valueExpression}");
-     }
-     sb.Append("object[] _attrs = new object[]{");
-     sb.Append(string.Join(", ", attrEntries));
-     sb.Append("};\n");
-
-     sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, " +
-         $"attrs: _attrs, ctx: ctx, name: name);");
-
-     sb.Append("if(_execute.must_record_gradient())\n{\n");
-     sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);");
-     sb.AppendLine("}"); // if
-
-     sb.AppendLine(returnsNothing
-         ? "return null;"
-         : returnsSingleTensor ? "return _result[0];" : "return _result;");
-
-     sb.AppendLine("}"); // body
- }
-
- /// <summary>
- /// Appends the parameter list of the generated <c>_eager_fallback</c> function: one
- /// parameter per input arg, one per attr returned by <see cref="GetAttrsDefaultValue"/>
- /// (no defaults), then <c>string name, Context ctx</c>.
- /// </summary>
- /// <param name="op">Op definition whose inputs and attrs become parameters.</param>
- /// <param name="sb">Builder that receives the parameter list (without parentheses).</param>
- public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb)
- {
-     // Names that are C# keywords get a trailing underscore.
-     static string Safe(string id) => SyntaxFactory.ParseToken(id).IsKeyword() ? $"{id}_" : id;
-
-     foreach (var input in op.InputArg)
-     {
-         // Inputs that carry a NumberAttr are declared as Tensors, all others as Tensor.
-         string tensorType = string.IsNullOrEmpty(input.NumberAttr) ? "Tensor" : "Tensors";
-         sb.Append($"{tensorType} {Safe(input.Name)}, ");
-     }
-
-     foreach (var (attrName, attrType, _) in GetAttrsDefaultValue(op, out _))
-     {
-         sb.Append($"{attrType} {Safe(attrName)}, ");
-     }
-
-     sb.Append("string name, Context ctx");
- }
-
- /// <summary>
- /// Appends statements that fill a <c>keywords</c> dictionary with every input arg and attr,
- /// then assign <c>_op</c> from <c>tf.OpDefLib._apply_op_helper</c>.
- /// </summary>
- /// <param name="op">Op definition being generated.</param>
- /// <param name="sb">Builder that receives the generated statements.</param>
- public void AppendOpHelperCall(OpDef op, StringBuilder sb)
- {
-     // Wrapper parameters are declared with a trailing '_' when the name is a C# keyword.
-     static string Escape(string identifier) =>
-         SyntaxFactory.ParseToken(identifier).IsKeyword() ? identifier + "_" : identifier;
-
-     sb.AppendLine("Dictionary<string, object> keywords = new();");
-     // Dictionary keys are the op's own names; the right-hand side is the C# variable.
-     foreach (var arg in op.InputArg)
-     {
-         sb.AppendLine($"keywords[\"{arg.Name}\"] = {Escape(arg.Name)};");
-     }
-     var attrValueDic = GetAttrsDefaultValue(op, out _);
-     foreach (var (key, _, _) in attrValueDic)
-     {
-         sb.AppendLine($"keywords[\"{key}\"] = {Escape(key)};");
-     }
-     sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
- }
-
- // name, type string, default value
- /// <summary>
- /// Maps the attrs of <paramref name="op"/> to the parameters the generated C# code needs.
- /// </summary>
- /// <remarks>
- /// Attrs that can be derived from the inputs are skipped: a <c>type</c> attr referenced by
- /// an input's <c>TypeAttr</c>, and an <c>int</c> attr referenced by an input's <c>NumberAttr</c>.
- /// An <c>int</c> attr without a default is given <c>"0"</c>.
- /// </remarks>
- /// <param name="op">Op definition to inspect.</param>
- /// <param name="dynamicDefaultValues">
- /// Receives, per attr name, the C# expression of a default that cannot be written as an
- /// optional-parameter constant (arrays, shapes). Such attrs are listed with the default
- /// <c>"null"</c> in the returned list.
- /// </param>
- /// <returns>
- /// One tuple per emitted attr: (attr name, C# type string, default-value expression). The
- /// third item is the sentinel <c>"NOVALUE"</c> when the attr has no usable default.
- /// </returns>
- /// <exception cref="NotImplementedException">An attr has a type this generator does not handle.</exception>
- public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues)
- {
-     const string NoValue = "NOVALUE";
-     // Generated literals must not depend on the culture of the machine running the generator.
-     var invariant = System.Globalization.CultureInfo.InvariantCulture;
-
-     // Formats a float as a C# float literal ('f' suffix; named constants for non-finite values).
-     static string FloatLiteral(float value)
-     {
-         if (float.IsNaN(value))
-         {
-             return "float.NaN";
-         }
-         if (float.IsPositiveInfinity(value))
-         {
-             return "float.PositiveInfinity";
-         }
-         if (float.IsNegativeInfinity(value))
-         {
-             return "float.NegativeInfinity";
-         }
-         return value.ToString(System.Globalization.CultureInfo.InvariantCulture) + "f";
-     }
-
-     // Qualified enum member, e.g. "TF_DataType.<member>".
-     static string DataTypeLiteral(TF_DataType dtype) =>
-         typeof(TF_DataType).Name + "." + Enum.GetName(typeof(TF_DataType), dtype);
-
-     dynamicDefaultValues = new();
-     List<(string, string, string)> res = new();
-     foreach (var attr in op.Attr)
-     {
-         // null when the attr has no default at all.
-         var defaultCase = attr.DefaultValue?.ValueCase;
-
-         switch (attr.Type)
-         {
-             case "type":
-                 if (op.InputArg.Any(x => x.TypeAttr == attr.Name))
-                 {
-                     break; // taken from the input tensor's dtype
-                 }
-                 res.Add((attr.Name, "TF_DataType",
-                     defaultCase == AttrValue.ValueOneofCase.Type
-                         ? DataTypeLiteral(attr.DefaultValue.Type.as_tf_dtype())
-                         : NoValue));
-                 break;
-
-             case "int":
-                 if (op.InputArg.Any(x => x.NumberAttr == attr.Name))
-                 {
-                     break; // taken from the length of the tensor-list input
-                 }
-                 res.Add((attr.Name, "int",
-                     defaultCase == AttrValue.ValueOneofCase.I
-                         ? attr.DefaultValue.I.ToString(invariant)
-                         : "0"));
-                 break;
-
-             case "float":
-                 res.Add((attr.Name, "float",
-                     defaultCase == AttrValue.ValueOneofCase.F
-                         ? FloatLiteral(attr.DefaultValue.F)
-                         : NoValue));
-                 break;
-
-             case "string":
-                 // FormatLiteral quotes the value and escapes quotes, backslashes and control characters.
-                 res.Add((attr.Name, "string",
-                     defaultCase == AttrValue.ValueOneofCase.S
-                         ? SymbolDisplay.FormatLiteral(attr.DefaultValue.S.ToStringUtf8(), true)
-                         : NoValue));
-                 break;
-
-             case "bool":
-                 res.Add((attr.Name, "bool",
-                     defaultCase == AttrValue.ValueOneofCase.B
-                         ? (attr.DefaultValue.B ? "true" : "false")
-                         : NoValue));
-                 break;
-
-             case "shape":
-                 if (defaultCase == AttrValue.ValueOneofCase.Shape)
-                 {
-                     if (!attr.DefaultValue.Shape.UnknownRank)
-                     {
-                         Shape shape = new Shape(attr.DefaultValue.Shape);
-                         dynamicDefaultValues[attr.Name] = $"new Shape({string.Join(", ", shape.dims)})";
-                     }
-                     res.Add((attr.Name, "Shape", "null"));
-                 }
-                 else
-                 {
-                     res.Add((attr.Name, "Shape", NoValue));
-                 }
-                 break;
-
-             case "list(type)":
-                 // List defaults are stored in the List member of the oneof.
-                 if (defaultCase == AttrValue.ValueOneofCase.List)
-                 {
-                     var items = attr.DefaultValue.List.Type.Select(v => DataTypeLiteral(v.as_tf_dtype()));
-                     dynamicDefaultValues[attr.Name] = "new TF_DataType[]{" + string.Join(", ", items) + "}";
-                     res.Add((attr.Name, "TF_DataType[]", "null"));
-                 }
-                 else
-                 {
-                     res.Add((attr.Name, "TF_DataType[]", NoValue));
-                 }
-                 break;
-
-             case "list(shape)":
-                 res.Add((attr.Name, "Shape[]", NoValue));
-                 break;
-
-             case "list(string)":
-                 if (defaultCase == AttrValue.ValueOneofCase.List)
-                 {
-                     // Each element must be a quoted, escaped C# string literal.
-                     var items = attr.DefaultValue.List.S.Select(v => SymbolDisplay.FormatLiteral(v.ToStringUtf8(), true));
-                     dynamicDefaultValues[attr.Name] = "new string[]{" + string.Join(", ", items) + "}";
-                     res.Add((attr.Name, "string[]", "null"));
-                 }
-                 else
-                 {
-                     res.Add((attr.Name, "string[]", NoValue));
-                 }
-                 break;
-
-             case "list(int)":
-                 if (defaultCase == AttrValue.ValueOneofCase.List)
-                 {
-                     var items = attr.DefaultValue.List.I.Select(v => ((int)v).ToString(invariant));
-                     dynamicDefaultValues[attr.Name] = "new int[]{" + string.Join(", ", items) + "}";
-                     res.Add((attr.Name, "int[]", "null"));
-                 }
-                 else
-                 {
-                     res.Add((attr.Name, "int[]", NoValue));
-                 }
-                 break;
-
-             case "list(float)":
-                 if (defaultCase == AttrValue.ValueOneofCase.List)
-                 {
-                     // Without the 'f' suffix the elements would be double literals.
-                     var items = attr.DefaultValue.List.F.Select(v => FloatLiteral(v));
-                     dynamicDefaultValues[attr.Name] = "new float[]{" + string.Join(", ", items) + "}";
-                     res.Add((attr.Name, "float[]", "null"));
-                 }
-                 else
-                 {
-                     res.Add((attr.Name, "float[]", NoValue));
-                 }
-                 break;
-
-             case "func":
-                 res.Add((attr.Name, "Func<Tensors, Tensors>", NoValue));
-                 break;
-
-             case "list(func)":
-                 res.Add((attr.Name, "Func<Tensors, Tensors>[]", NoValue));
-                 break;
-
-             case "tensor":
-                 res.Add((attr.Name, "TensorProto", NoValue));
-                 break;
-
-             default:
-                 throw new NotImplementedException(
-                     $"Attr type '{attr.Type}' (attr '{attr.Name}' of op '{op.Name}') is not supported.");
-         }
-     }
-     return res;
- }
-
- /// <summary>
- /// Returns <c>true</c> when at least one input arg of <paramref name="op"/> has
- /// <c>IsRef</c> set, otherwise <c>false</c>.
- /// </summary>
- private static bool HasRefArgs(OpDef op)
- {
-     foreach (var input in op.InputArg)
-     {
-         if (input.IsRef)
-         {
-             return true;
-         }
-     }
-     return false;
- }
- }
- }
|