using System; using System.Collections.Generic; using System.ComponentModel; using System.Dynamic; using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; using static Tensorflow.OpDef.Types; namespace Tensorflow { public class OpDefLibrary : Python { public Operation _apply_op_helper(string op_type_name, string name = null, object args = null) { Dictionary keywords = ConvertToDict(args); var g = ops.get_default_graph(); var op_def = g.GetOpDef(op_type_name); // Default name if not specified. if (String.IsNullOrEmpty(name)) name = op_type_name; // Check for deprecation if (op_def.Deprecation != null && op_def.Deprecation.Version > 0) { } var default_type_attr_map = new Dictionary(); foreach (var attr_def in op_def.Attr) { if (attr_def.Type != "type") continue; var key = attr_def.Name; if (attr_def.DefaultValue != null) { default_type_attr_map[key] = attr_def.DefaultValue.Type; } } var attrs = new Dictionary(); var inputs = new List(); var input_types = new List(); object values = null; return with(ops.name_scope(name), scope => { var inferred_from = new Dictionary(); var base_types = new List(); var types = new List(); // Perform input type inference foreach (var input_arg in op_def.InputArg) { var input_name = input_arg.Name; if (keywords.ContainsKey(input_name)) values = keywords[input_name]; else if (keywords.ContainsKey(input_name + "_")) { input_name += "_"; values = keywords[input_name]; } else throw new TypeError("No argument for input " + input_name); // Goals: // * Convert values to Tensors if it contains constants. // * Verify that values is a list if that matches the input_arg's // type. // * If the input_arg's type is determined by attrs, either set // those attrs and validate those attr values are legal (if // they have not yet been set) or validate the input matches // the type indicated by the attrs (if they have already been // inferred via an earlier input). // * If the input_arg has an explicit type, make sure the input // conforms. DataType dtype = DataType.DtInvalid; DataType default_dtype = DataType.DtInvalid; if (_IsListParameter(input_arg)) { if (!_IsListValue(values)) throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}."); if(input_arg.Type != DataType.DtInvalid) dtype = input_arg.Type; else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) { if (attrs.ContainsKey(input_arg.TypeAttr)) dtype = (DataType)attrs[input_arg.TypeAttr]; else if (values is Tensor[] values1) dtype = values1[0].dtype.as_datatype_enum(); if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr)) default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; } if(input_arg.IsRef && dtype != DataType.DtInvalid) dtype = dtype.as_base_dtype(); values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype.as_tf_dtype(), preferred_dtype: default_dtype.as_tf_dtype(), as_ref: input_arg.IsRef); } else { if (input_arg.Type != DataType.DtInvalid) dtype = input_arg.Type; else if (attrs.ContainsKey(input_arg.TypeAttr)) dtype = (DataType)attrs[input_arg.TypeAttr]; else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; var value = ops.internal_convert_to_tensor(values, name: input_name, dtype: dtype.as_tf_dtype(), as_ref: input_arg.IsRef, preferred_dtype: default_dtype.as_tf_dtype()); //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) //attrs[input_arg.TypeAttr] = values.dtype; values = new Tensor[] { value }; } if (values is Tensor[] values2) { types = values2.Select(x => x.dtype).ToList(); inputs.AddRange(values2); base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList(); } else throw new NotImplementedException("_IsListParameter"); SetAttrs(op_type_name, input_arg, op_def, attrs, inferred_from, types, base_types, input_types, values); } // Process remaining attrs foreach (var attr in op_def.Attr) { if (keywords.ContainsKey(attr.Name)) { attrs[attr.Name] = keywords[attr.Name]; } } // Convert attr values to AttrValue protos. var attr_protos = new Dictionary(); foreach (var attr_def in op_def.Attr) { var key = attr_def.Name; var value = attrs[key]; if (!attrs.ContainsKey(key)) Console.WriteLine($"_apply_op_helper: key '{key}' is not found in '{op_def.Name}' operation's attr_def."); attr_protos[key] = SetAttrValue(op_def, attr_def, value); } attrs.Clear(); // Determine output types (possibly using attrs) var output_types = new List(); foreach (var arg in op_def.OutputArg) { types = new List(); if (!string.IsNullOrEmpty(arg.NumberAttr)) { } else if (!string.IsNullOrEmpty(arg.TypeAttr)) { types = new List() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; } if (arg.IsRef) types = types.Select(x => x.as_ref()).ToList(); output_types.AddRange(types); } // Add Op to graph var op = g.create_op(op_type_name, inputs.ToArray(), output_types.ToArray(), name: scope, input_types: input_types.ToArray(), attrs: attr_protos, op_def: op_def); return op; }); } private void SetAttrs(string op_type_name, ArgDef input_arg, OpDef op_def, Dictionary attrs, Dictionary inferred_from, List types, List base_types, List input_types, dynamic values) { var input_name = input_arg.Name; if (!string.IsNullOrEmpty(input_arg.NumberAttr)) { if (attrs.ContainsKey(input_arg.NumberAttr)) { } else { attrs[input_arg.NumberAttr] = (values as Tensor[]).Length; inferred_from[input_arg.NumberAttr] = input_name; var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr); if (num_attr.HasMinimum && (values as Tensor[]).Length < num_attr.Minimum) throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {(values as Tensor[]).Length} shorter " + $"than minimum length {num_attr.Minimum}"); } // All tensors must have the same base type. if (input_arg.Type != DataType.DtInvalid) { } else { attrs[input_arg.TypeAttr] = base_types[0]; inferred_from[input_arg.TypeAttr] = input_name; var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); } } else if (!string.IsNullOrEmpty(input_arg.TypeAttr)) { var attr_value = base_types[0]; if (attrs.ContainsKey(input_arg.TypeAttr)) { } else { attrs[input_arg.TypeAttr] = attr_value; inferred_from[input_arg.TypeAttr] = input_name; } } else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) { var attr_value = base_types; if (attrs.ContainsKey(input_arg.TypeListAttr)) { } else { attrs[input_arg.TypeListAttr] = attr_value; inferred_from[input_arg.TypeListAttr] = input_name; } } if (input_arg.IsRef) input_types.AddRange(types); else input_types.AddRange(base_types); } public DataType _MakeType(TF_DataType v, AttrDef attr_def) { return v.as_base_dtype().as_datatype_enum(); } private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value) { var attr_value = new AttrValue(); if (attr_def.Type.StartsWith("list(")) { if (attr_def.HasMinimum) ; attr_value.List = new AttrValue.Types.ListValue(); } switch (attr_def.Type) { case "string": attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); break; case "type": attr_value.Type = _MakeType((TF_DataType)value, attr_def); break; case "list(type)": attr_value.List.Type.AddRange((value as IList).Select(x => _MakeType(x, attr_def))); break; case "list(int)": attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x))); break; case "bool": attr_value.B = (bool)value; break; case "float": attr_value.F = (float)value; break; case "int": attr_value.I = (int)value; if (attr_def.HasMinimum && attr_value.I < attr_def.Minimum) throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}."); break; case "shape": if (value == null && attr_def.DefaultValue != null) attr_value.Shape = attr_def.DefaultValue.Shape; if(value is TensorShape val1) attr_value.Shape = val1.as_proto(); else if(value is long[] val2) attr_value.Shape = tensor_util.as_shape(val2); else if (value is int[] val3) attr_value.Shape = tensor_util.as_shape(val3); break; default: throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); } return attr_value; } private bool _IsListParameter(ArgDef arg) { if (!String.IsNullOrEmpty(arg.NumberAttr)) return true; else if (!String.IsNullOrEmpty(arg.TypeListAttr)) return true; else return false; } private bool _IsListValue(object v) { switch (v) { case Tensor[] val: return true; default: return false; } } } }