using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection.Metadata.Ecma335;
using System.Text;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp;

namespace Tensorflow.CodeGen
{
    /// <summary>
    /// Emits the C# source text of an op wrapper function and its
    /// <c>_eager_fallback</c> companion from an <see cref="OpDef"/>.
    /// </summary>
    /// <remarks>
    /// All output is appended to a caller-supplied <see cref="StringBuilder"/>; no formatting
    /// (indentation) is applied to the emitted text. Input and attr names that are C# keywords
    /// are emitted with a trailing underscore wherever they are used as identifiers, while the
    /// original names are kept wherever they are used as string keys.
    /// </remarks>
    public class FunctionGenerator
    {
        /// <summary>Marker stored as the default value of attrs that have no default.</summary>
        private const string NoValue = "NOVALUE";

        /// <summary>
        /// Appends the complete wrapper function for <paramref name="op"/>, followed by its
        /// eager-fallback function.
        /// </summary>
        /// <param name="op">Definition of the op to generate code for.</param>
        /// <param name="sb">Builder receiving the generated source.</param>
        public void AppendFunction(OpDef op, StringBuilder sb)
        {
            // TODO: add descriptions
            sb.Append("public static ");
            AppendReturnType(op, sb);

            string funcName = Utils.ConvertToUnderscore(op.Name);
            if (IsCSharpKeyword(funcName))
            {
                // Function names are disambiguated with a leading underscore.
                funcName = $"_{funcName}";
            }
            sb.Append($" {funcName}(");

            // define args
            AppendArgs(op, sb);
            sb.Append(")\n{\n");

            // begin to write main body
            sb.AppendLine("var _ctx = tf.Context;");

            var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues);

            // Defaults that cannot be written as parameter defaults are assigned
            // at the top of the body when the argument is null.
            foreach (var (name, expr) in dynamicDefaultValues)
            {
                AppendNullFallback(sb, EscapeIdentifier(name), expr);
            }

            sb.AppendLine("if(_ctx.executing_eagerly()){");
            if (HasRefArgs(op))
            {
                var refArg = op.InputArg.First(x => x.IsRef);
                sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {refArg.Name} is a ref.\");");
            }
            else
            {
                // The generated code tries the fast path first, then the fallback;
                // both attempts are wrapped in try/catch blocks with empty handlers.
                sb.Append("try\n{\n");
                AppendFastPathExecute(op, sb);
                AppendReturn(op, sb, "null", "_fast_path_result");
                sb.AppendLine("}"); // try

                sb.Append("catch(Exception)\n{\n");
                sb.AppendLine("}"); // catch

                sb.Append("try\n{\n");
                AppendEagerFallbackCall(op, sb);
                sb.AppendLine("}"); // try

                sb.Append("catch(Exception)\n{\n");
                sb.AppendLine("}"); // catch
            }
            sb.AppendLine("}"); // if

            // String attrs with a default are re-defaulted when the caller passed null.
            foreach (var (name, _, value) in attrValueDic.Where(x => x.Item2 == "string"))
            {
                if (value != NoValue)
                {
                    AppendNullFallback(sb, EscapeIdentifier(name), value);
                }
            }

            // begin to use op helper.
            AppendOpHelperCall(op, sb);
            sb.AppendLine("var _result = _op.outputs;");

            // check if it needs to record gradient.
            sb.Append("if(_execute.must_record_gradient())\n{\n");
            var recordedAttrs = op.Attr.Select(attr =>
            {
                string getter = attr.Type switch
                {
                    "type" => "_get_attr_type",
                    "int" => "_get_attr_int",
                    "bool" => "_get_attr_bool",
                    _ => "get_attr",
                };
                // Both strings are attr names of the op, so the unescaped name is used.
                return $"\"{attr.Name}\", _op.{getter}(\"{attr.Name}\")";
            });
            sb.Append("object[] _attrs = new object[]{");
            sb.Append(string.Join(", ", recordedAttrs));
            sb.Append("};\n");
            sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);");
            sb.AppendLine("}"); // if

            AppendReturn(op, sb, "_op", "_result");
            sb.AppendLine("}"); // body
            sb.AppendLine();

            AppendEagerFallbackDefinition(op, sb);
        }

        /// <summary>
        /// Appends the parameter list of the wrapper function: inputs, then attrs without a
        /// default, then attrs with a default, then the trailing <c>name</c> parameter.
        /// </summary>
        /// <param name="op">Definition of the op to generate code for.</param>
        /// <param name="sb">Builder receiving the generated source.</param>
        public void AppendArgs(OpDef op, StringBuilder sb)
        {
            AppendInputParameters(op, sb);

            var attrValueDic = GetAttrsDefaultValue(op, out _);

            // Required parameters must precede optional ones in C#.
            foreach (var (key, typeStr, _) in attrValueDic.Where(x => x.Item3 == NoValue))
            {
                sb.Append($"{typeStr} {EscapeIdentifier(key)}, ");
            }
            foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 != NoValue))
            {
                sb.Append($"{typeStr} {EscapeIdentifier(key)} = {value}, ");
            }

            sb.Append("string? name = null");
        }

        /// <summary>
        /// Appends the statement that calls <c>tf.Runner.TFE_FastPathExecute</c> and stores the
        /// outcome in <c>_fast_path_result</c>.
        /// </summary>
        /// <param name="op">Definition of the op to generate code for.</param>
        /// <param name="sb">Builder receiving the generated source.</param>
        public void AppendFastPathExecute(OpDef op, StringBuilder sb)
        {
            sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)");

            sb.Append("{ args = new object[]{ ");
            sb.Append(string.Join(", ", op.InputArg.Select(arg => EscapeIdentifier(arg.Name))));

            sb.Append("}, attrs = new Dictionary<string, object>(){ ");
            var attrValueDic = GetAttrsDefaultValue(op, out _);
            // The key is the attr name; the value is the (possibly escaped) parameter.
            sb.Append(string.Join(", ", attrValueDic.Select(x => $"[\"{x.Item1}\"] = {EscapeIdentifier(x.Item1)}")));

            sb.Append("}});\n");
        }

        /// <summary>
        /// Appends the <c>return</c> statement that forwards all inputs and attrs to the
        /// generated <c>_eager_fallback</c> function.
        /// </summary>
        /// <param name="op">Definition of the op to generate code for.</param>
        /// <param name="sb">Builder receiving the generated source.</param>
        public void AppendEagerFallbackCall(OpDef op, StringBuilder sb)
        {
            string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback";
            sb.Append($"return {funcName}(");

            foreach (var arg in op.InputArg)
            {
                sb.Append($"{EscapeIdentifier(arg.Name)}, ");
            }

            var attrValueDic = GetAttrsDefaultValue(op, out _);
            foreach (var (key, _, _) in attrValueDic)
            {
                // Named arguments must match the parameter names emitted by
                // AppendFallBackFunctionArgs, which are the escaped identifiers.
                string parameterName = EscapeIdentifier(key);
                sb.Append($"{parameterName}: {parameterName}, ");
            }

            sb.Append("name: name, ctx: _ctx);\n");
        }

        /// <summary>
        /// Appends the full definition of the <c>_eager_fallback</c> function for
        /// <paramref name="op"/>. When any input is a ref, the emitted body only throws.
        /// </summary>
        /// <param name="op">Definition of the op to generate code for.</param>
        /// <param name="sb">Builder receiving the generated source.</param>
        public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb)
        {
            sb.Append("public static ");
            int outputArgsCount = op.OutputArg.Count;
            AppendReturnType(op, sb);

            string funcName = Utils.ConvertToUnderscore(op.Name);
            sb.Append($" {funcName}_eager_fallback(");
            AppendFallBackFunctionArgs(op, sb);
            sb.Append(")\n{\n");

            var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef);
            if (possibleRefArg is not null)
            {
                sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution."
                    + $" Arg '{possibleRefArg.Name}' is a ref.\");");
                sb.AppendLine("}"); // body
                return;
            }

            if (op.InputArg.Any(x => !string.IsNullOrEmpty(x.NumberAttr)))
            {
                // At least one input is a tensor list, so the flat input array is
                // assembled through an intermediate list.
                sb.AppendLine("List<Tensor> _inputs_flat_list = new();");
                foreach (var arg in op.InputArg)
                {
                    string method = string.IsNullOrEmpty(arg.NumberAttr) ? "Add" : "AddRange";
                    sb.AppendLine($"_inputs_flat_list.{method}({EscapeIdentifier(arg.Name)});");
                }
                sb.AppendLine($"var _inputs_flat = _inputs_flat_list.ToArray();");
            }
            else
            {
                sb.Append("Tensor[] _inputs_flat = new Tensor[]{");
                sb.Append(string.Join(", ", op.InputArg.Select(arg => EscapeIdentifier(arg.Name))));
                sb.Append("};\n");
            }

            var attrEntries = op.Attr.Select(attr =>
            {
                if (attr.Type == "type")
                {
                    // Prefer reading the dtype from the first input typed by this attr.
                    var typedInput = op.InputArg.FirstOrDefault(x => x.TypeAttr == attr.Name);
                    if (typedInput is not null)
                    {
                        return $"\"{attr.Name}\", {EscapeIdentifier(typedInput.Name)}.dtype";
                    }
                }
                else if (attr.Type == "int")
                {
                    // An int attr that counts a list input is read from that input's length.
                    var countedInput = op.InputArg.FirstOrDefault(x => x.NumberAttr == attr.Name);
                    if (countedInput is not null)
                    {
                        return $"\"{attr.Name}\", {EscapeIdentifier(countedInput.Name)}.Length";
                    }
                }
                // Otherwise the attr is a parameter of the fallback function.
                return $"\"{attr.Name}\", {EscapeIdentifier(attr.Name)}";
            });
            sb.Append("object[] _attrs = new object[]{");
            sb.Append(string.Join(", ", attrEntries));
            sb.Append("};\n");

            sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, "
                + $"attrs: _attrs, ctx: ctx, name: name);");

            sb.Append("if(_execute.must_record_gradient())\n{\n");
            sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);");
            sb.AppendLine("}"); // if

            AppendReturn(op, sb, "null", "_result");
            sb.AppendLine("}"); // body
        }

        /// <summary>
        /// Appends the parameter list of the <c>_eager_fallback</c> function: inputs, every
        /// attr returned by <see cref="GetAttrsDefaultValue"/> (without defaults), then
        /// <c>name</c> and <c>ctx</c>.
        /// </summary>
        /// <param name="op">Definition of the op to generate code for.</param>
        /// <param name="sb">Builder receiving the generated source.</param>
        public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb)
        {
            AppendInputParameters(op, sb);

            var attrValueDic = GetAttrsDefaultValue(op, out _);
            foreach (var (key, typeStr, _) in attrValueDic)
            {
                sb.Append($"{typeStr} {EscapeIdentifier(key)}, ");
            }

            sb.Append($"string name, Context ctx");
        }

        /// <summary>
        /// Appends the statements that fill a <c>keywords</c> dictionary with all inputs and
        /// attrs and pass it to <c>tf.OpDefLib._apply_op_helper</c>, storing the result in <c>_op</c>.
        /// </summary>
        /// <param name="op">Definition of the op to generate code for.</param>
        /// <param name="sb">Builder receiving the generated source.</param>
        public void AppendOpHelperCall(OpDef op, StringBuilder sb)
        {
            sb.AppendLine("Dictionary<string, object> keywords = new();");

            foreach (var arg in op.InputArg)
            {
                sb.AppendLine($"keywords[\"{arg.Name}\"] = {EscapeIdentifier(arg.Name)};");
            }

            var attrValueDic = GetAttrsDefaultValue(op, out _);
            foreach (var (key, _, _) in attrValueDic)
            {
                sb.AppendLine($"keywords[\"{key}\"] = {EscapeIdentifier(key)};");
            }

            sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);");
        }

        /// <summary>
        /// Computes, for every attr that becomes a parameter of the generated function, its
        /// name, its C# type and the source text of its default value.
        /// </summary>
        /// <remarks>
        /// Skipped attrs: <c>type</c> attrs referenced by an input's <c>TypeAttr</c> and
        /// <c>int</c> attrs referenced by an input's <c>NumberAttr</c>. An <c>int</c> attr
        /// without a default gets <c>"0"</c>; every other attr without a usable default gets
        /// <c>"NOVALUE"</c>.
        /// </remarks>
        /// <param name="op">Definition of the op to inspect.</param>
        /// <param name="dynamicDefaultValues">
        /// Receives, keyed by attr name, the construction expression of each default that is
        /// listed with a <c>"null"</c> default value and has to be assigned inside the body.
        /// </param>
        /// <returns>A list of (name, type string, default value) tuples in attr order.</returns>
        /// <exception cref="NotImplementedException">An attr has an unsupported type.</exception>
        // name, type string, default value
        public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues)
        {
            dynamicDefaultValues = new();
            List<(string, string, string)> res = new();

            foreach (var attr in op.Attr)
            {
                // True when the attr carries a default of the expected oneof case.
                bool HasDefault(AttrValue.ValueOneofCase expected)
                {
                    return attr.DefaultValue is not null && attr.DefaultValue.ValueCase == expected;
                }

                switch (attr.Type)
                {
                    case "type":
                    {
                        if (op.InputArg.Any(x => x.TypeAttr == attr.Name))
                        {
                            break;
                        }
                        string value = HasDefault(AttrValue.ValueOneofCase.Type)
                            ? FormatDataTypeLiteral(attr.DefaultValue.Type.as_tf_dtype())
                            : NoValue;
                        res.Add((attr.Name, "TF_DataType", value));
                        break;
                    }
                    case "int":
                    {
                        if (op.InputArg.Any(x => x.NumberAttr == attr.Name))
                        {
                            break;
                        }
                        string value = HasDefault(AttrValue.ValueOneofCase.I)
                            ? attr.DefaultValue.I.ToString(CultureInfo.InvariantCulture)
                            : "0";
                        res.Add((attr.Name, "int", value));
                        break;
                    }
                    case "float":
                    {
                        string value = HasDefault(AttrValue.ValueOneofCase.F)
                            ? FormatFloatLiteral(attr.DefaultValue.F)
                            : NoValue;
                        res.Add((attr.Name, "float", value));
                        break;
                    }
                    case "string":
                    {
                        string value = HasDefault(AttrValue.ValueOneofCase.S)
                            ? FormatStringLiteral(attr.DefaultValue.S.ToStringUtf8())
                            : NoValue;
                        res.Add((attr.Name, "string", value));
                        break;
                    }
                    case "bool":
                    {
                        string value = HasDefault(AttrValue.ValueOneofCase.B)
                            ? (attr.DefaultValue.B ? "true" : "false")
                            : NoValue;
                        res.Add((attr.Name, "bool", value));
                        break;
                    }
                    case "shape":
                    {
                        if (!HasDefault(AttrValue.ValueOneofCase.Shape))
                        {
                            res.Add((attr.Name, "Shape", NoValue));
                            break;
                        }
                        if (!attr.DefaultValue.Shape.UnknownRank)
                        {
                            Shape shape = new Shape(attr.DefaultValue.Shape);
                            dynamicDefaultValues[attr.Name] = $"new Shape({string.Join(", ", shape.dims)})";
                        }
                        res.Add((attr.Name, "Shape", "null"));
                        break;
                    }
                    case "list(type)":
                    {
                        if (HasDefault(AttrValue.ValueOneofCase.List))
                        {
                            var items = attr.DefaultValue.List.Type.Select(x => FormatDataTypeLiteral(x.as_tf_dtype()));
                            dynamicDefaultValues[attr.Name] = "new TF_DataType[]{" + string.Join(", ", items) + "}";
                            res.Add((attr.Name, "TF_DataType[]", "null"));
                        }
                        else
                        {
                            res.Add((attr.Name, "TF_DataType[]", NoValue));
                        }
                        break;
                    }
                    case "list(shape)":
                        res.Add((attr.Name, "Shape[]", NoValue));
                        break;
                    case "list(string)":
                    {
                        if (HasDefault(AttrValue.ValueOneofCase.List))
                        {
                            var items = attr.DefaultValue.List.S.Select(x => FormatStringLiteral(x.ToStringUtf8()));
                            dynamicDefaultValues[attr.Name] = "new string[]{" + string.Join(", ", items) + "}";
                            res.Add((attr.Name, "string[]", "null"));
                        }
                        else
                        {
                            res.Add((attr.Name, "string[]", NoValue));
                        }
                        break;
                    }
                    case "list(int)":
                    {
                        if (HasDefault(AttrValue.ValueOneofCase.List))
                        {
                            var items = attr.DefaultValue.List.I.Select(x => ((int)x).ToString(CultureInfo.InvariantCulture));
                            dynamicDefaultValues[attr.Name] = "new int[]{" + string.Join(", ", items) + "}";
                            res.Add((attr.Name, "int[]", "null"));
                        }
                        else
                        {
                            res.Add((attr.Name, "int[]", NoValue));
                        }
                        break;
                    }
                    case "list(float)":
                    {
                        if (HasDefault(AttrValue.ValueOneofCase.List))
                        {
                            var items = attr.DefaultValue.List.F.Select(FormatFloatLiteral);
                            dynamicDefaultValues[attr.Name] = "new float[]{" + string.Join(", ", items) + "}";
                            res.Add((attr.Name, "float[]", "null"));
                        }
                        else
                        {
                            res.Add((attr.Name, "float[]", NoValue));
                        }
                        break;
                    }
                    // NOTE(review): "Func" / "Func[]" are emitted without generic arguments;
                    // confirm the consuming project declares a matching type.
                    case "func":
                        res.Add((attr.Name, "Func", NoValue));
                        break;
                    case "list(func)":
                        res.Add((attr.Name, "Func[]", NoValue));
                        break;
                    case "tensor":
                        res.Add((attr.Name, "TensorProto", NoValue));
                        break;
                    default:
                        throw new NotImplementedException(
                            $"Attr type '{attr.Type}' (attr '{attr.Name}' of op '{op.Name}') is not supported.");
                }
            }

            return res;
        }

        /// <summary>Returns whether any input of <paramref name="op"/> is a ref.</summary>
        private static bool HasRefArgs(OpDef op)
        {
            return op.InputArg.Any(x => x.IsRef);
        }

        /// <summary>Returns whether <paramref name="identifier"/> parses as a C# keyword token.</summary>
        private static bool IsCSharpKeyword(string identifier)
        {
            return SyntaxFactory.ParseToken(identifier).IsKeyword();
        }

        /// <summary>Appends an underscore to <paramref name="identifier"/> when it is a C# keyword.</summary>
        private static string EscapeIdentifier(string identifier)
        {
            return IsCSharpKeyword(identifier) ? identifier + "_" : identifier;
        }

        /// <summary>Returns whether the op has exactly one output that is not a numbered list.</summary>
        private static bool ReturnsSingleTensor(OpDef op)
        {
            return op.OutputArg.Count == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr);
        }

        /// <summary>Appends the return type (with a trailing space) of the generated functions.</summary>
        private static void AppendReturnType(OpDef op, StringBuilder sb)
        {
            if (op.OutputArg.Count == 0)
            {
                sb.Append("Operation ");
            }
            else if (ReturnsSingleTensor(op))
            {
                sb.Append("Tensor ");
            }
            else
            {
                sb.Append("Tensor[] ");
            }
        }

        /// <summary>
        /// Appends a return statement: <paramref name="noOutputExpr"/> for ops without outputs,
        /// element 0 of <paramref name="resultExpr"/> for single-tensor ops, otherwise
        /// <paramref name="resultExpr"/> itself.
        /// </summary>
        private static void AppendReturn(OpDef op, StringBuilder sb, string noOutputExpr, string resultExpr)
        {
            if (op.OutputArg.Count == 0)
            {
                sb.AppendLine($"return {noOutputExpr};");
            }
            else if (ReturnsSingleTensor(op))
            {
                sb.AppendLine($"return {resultExpr}[0];");
            }
            else
            {
                sb.AppendLine($"return {resultExpr};");
            }
        }

        /// <summary>Appends one <c>Tensor</c>/<c>Tensors</c> parameter (with trailing ", ") per input.</summary>
        private static void AppendInputParameters(OpDef op, StringBuilder sb)
        {
            foreach (var arg in op.InputArg)
            {
                // Inputs with a number attr are tensor lists.
                string typeName = string.IsNullOrEmpty(arg.NumberAttr) ? "Tensor" : "Tensors";
                sb.Append($"{typeName} {EscapeIdentifier(arg.Name)}, ");
            }
        }

        /// <summary>Appends an <c>if(x is null) { x = expr; }</c> block.</summary>
        private static void AppendNullFallback(StringBuilder sb, string identifier, string expression)
        {
            sb.AppendLine($"if({identifier} is null)");
            sb.AppendLine("{");
            sb.AppendLine($"{identifier} = {expression};");
            sb.AppendLine("}");
        }

        /// <summary>Formats a data type as a qualified <c>TF_DataType</c> member access.</summary>
        private static string FormatDataTypeLiteral(TF_DataType dtype)
        {
            return typeof(TF_DataType).Name + "." + Enum.GetName(typeof(TF_DataType), dtype);
        }

        /// <summary>Formats a float as culture-independent C# source text.</summary>
        private static string FormatFloatLiteral(float value)
        {
            if (float.IsNaN(value))
            {
                return "float.NaN";
            }
            if (float.IsPositiveInfinity(value))
            {
                return "float.PositiveInfinity";
            }
            if (float.IsNegativeInfinity(value))
            {
                return "float.NegativeInfinity";
            }
            return value.ToString(CultureInfo.InvariantCulture) + "f";
        }

        /// <summary>Formats a string as a quoted and escaped C# string literal.</summary>
        private static string FormatStringLiteral(string value)
        {
            return SymbolDisplay.FormatLiteral(value, quote: true);
        }
    }
}