|
|
@@ -60,201 +60,204 @@ namespace Tensorflow |
|
|
|
object values = null; |
|
|
|
|
|
|
|
g.as_default(); |
|
|
|
var ret_op = tf_with(ops.name_scope(name), scope => |
|
|
|
{ |
|
|
|
var inferred_from = new Dictionary<string, object>(); |
|
|
|
var base_types = new List<TF_DataType>(); |
|
|
|
var types = new List<TF_DataType>(); |
|
|
|
string _scope_name = scope; |
|
|
|
|
|
|
|
// Perform input type inference |
|
|
|
foreach (var (i, input_arg) in enumerate(op_def.InputArg)) |
|
|
|
{ |
|
|
|
var input_name = input_arg.Name; |
|
|
|
var scope = ops.name_scope(name); |
|
|
|
scope.__enter__(); |
|
|
|
|
|
|
|
var inferred_from = new Dictionary<string, object>(); |
|
|
|
var base_types = new List<TF_DataType>(); |
|
|
|
var types = new List<TF_DataType>(); |
|
|
|
string _scope_name = scope; |
|
|
|
|
|
|
|
// Perform input type inference |
|
|
|
foreach (var (i, input_arg) in enumerate(op_def.InputArg)) |
|
|
|
{ |
|
|
|
var input_name = input_arg.Name; |
|
|
|
|
|
|
|
if (keywords.ContainsKey(input_name)) |
|
|
|
values = keywords[input_name]; |
|
|
|
else if (keywords.ContainsKey(input_name + "_")) |
|
|
|
{ |
|
|
|
input_name += "_"; |
|
|
|
values = keywords[input_name]; |
|
|
|
} |
|
|
|
else if (keywords.ContainsKey($"input_{i}")) |
|
|
|
{ |
|
|
|
values = keywords[$"input_{i}"]; |
|
|
|
} |
|
|
|
else |
|
|
|
throw new TypeError("No argument for input " + input_name); |
|
|
|
|
|
|
|
// Goals: |
|
|
|
// * Convert values to Tensors if it contains constants. |
|
|
|
// * Verify that values is a list if that matches the input_arg's |
|
|
|
// type. |
|
|
|
// * If the input_arg's type is determined by attrs, either set |
|
|
|
// those attrs and validate those attr values are legal (if |
|
|
|
// they have not yet been set) or validate the input matches |
|
|
|
// the type indicated by the attrs (if they have already been |
|
|
|
// inferred via an earlier input). |
|
|
|
// * If the input_arg has an explicit type, make sure the input |
|
|
|
// conforms. |
|
|
|
|
|
|
|
DataType dtype = DataType.DtInvalid; |
|
|
|
DataType default_dtype = DataType.DtInvalid; |
|
|
|
|
|
|
|
if (_IsListParameter(input_arg)) |
|
|
|
{ |
|
|
|
if (!_IsListValue(values)) |
|
|
|
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}."); |
|
|
|
if (input_arg.Type != DataType.DtInvalid) |
|
|
|
dtype = input_arg.Type; |
|
|
|
else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) |
|
|
|
{ |
|
|
|
if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
|
dtype = (DataType)attrs[input_arg.TypeAttr]; |
|
|
|
else |
|
|
|
switch (values) |
|
|
|
{ |
|
|
|
case Tensor[] values1: |
|
|
|
dtype = values1[0].dtype.as_datatype_enum(); |
|
|
|
break; |
|
|
|
case object[] values1: |
|
|
|
foreach (var t in values1) |
|
|
|
if (t is Tensor tensor) |
|
|
|
{ |
|
|
|
dtype = tensor.dtype.as_datatype_enum(); |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
default: |
|
|
|
throw new NotImplementedException($"can't infer the dtype for {values.GetType()}"); |
|
|
|
} |
|
|
|
|
|
|
|
if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
|
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; |
|
|
|
} |
|
|
|
|
|
|
|
if (!input_arg.IsRef && dtype != DataType.DtInvalid) |
|
|
|
dtype = dtype.as_base_dtype(); |
|
|
|
|
|
|
|
values = ops.internal_convert_n_to_tensor(values as object[], |
|
|
|
name: input_arg.Name, |
|
|
|
dtype: dtype.as_tf_dtype(), |
|
|
|
preferred_dtype: default_dtype.as_tf_dtype(), |
|
|
|
as_ref: input_arg.IsRef); |
|
|
|
} |
|
|
|
else |
|
|
|
if (keywords.ContainsKey(input_name)) |
|
|
|
values = keywords[input_name]; |
|
|
|
else if (keywords.ContainsKey(input_name + "_")) |
|
|
|
{ |
|
|
|
input_name += "_"; |
|
|
|
values = keywords[input_name]; |
|
|
|
} |
|
|
|
else if (keywords.ContainsKey($"input_{i}")) |
|
|
|
{ |
|
|
|
values = keywords[$"input_{i}"]; |
|
|
|
} |
|
|
|
else |
|
|
|
throw new TypeError("No argument for input " + input_name); |
|
|
|
|
|
|
|
// Goals: |
|
|
|
// * Convert values to Tensors if it contains constants. |
|
|
|
// * Verify that values is a list if that matches the input_arg's |
|
|
|
// type. |
|
|
|
// * If the input_arg's type is determined by attrs, either set |
|
|
|
// those attrs and validate those attr values are legal (if |
|
|
|
// they have not yet been set) or validate the input matches |
|
|
|
// the type indicated by the attrs (if they have already been |
|
|
|
// inferred via an earlier input). |
|
|
|
// * If the input_arg has an explicit type, make sure the input |
|
|
|
// conforms. |
|
|
|
|
|
|
|
DataType dtype = DataType.DtInvalid; |
|
|
|
DataType default_dtype = DataType.DtInvalid; |
|
|
|
|
|
|
|
if (_IsListParameter(input_arg)) |
|
|
|
{ |
|
|
|
if (!_IsListValue(values)) |
|
|
|
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}."); |
|
|
|
if (input_arg.Type != DataType.DtInvalid) |
|
|
|
dtype = input_arg.Type; |
|
|
|
else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) |
|
|
|
{ |
|
|
|
if (input_arg.Type != DataType.DtInvalid) |
|
|
|
dtype = input_arg.Type; |
|
|
|
else if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
|
if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
|
dtype = (DataType)attrs[input_arg.TypeAttr]; |
|
|
|
else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid) |
|
|
|
dtype = DataType.DtString; |
|
|
|
else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
|
else |
|
|
|
switch (values) |
|
|
|
{ |
|
|
|
case Tensor[] values1: |
|
|
|
dtype = values1[0].dtype.as_datatype_enum(); |
|
|
|
break; |
|
|
|
case object[] values1: |
|
|
|
foreach (var t in values1) |
|
|
|
if (t is Tensor tensor) |
|
|
|
{ |
|
|
|
dtype = tensor.dtype.as_datatype_enum(); |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
default: |
|
|
|
throw new NotImplementedException($"can't infer the dtype for {values.GetType()}"); |
|
|
|
} |
|
|
|
|
|
|
|
if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
|
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; |
|
|
|
} |
|
|
|
|
|
|
|
var value = ops.convert_to_tensor(values, |
|
|
|
name: input_name, |
|
|
|
dtype: dtype.as_tf_dtype(), |
|
|
|
as_ref: input_arg.IsRef, |
|
|
|
preferred_dtype: default_dtype.as_tf_dtype()); |
|
|
|
|
|
|
|
//if (!String.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
//attrs[input_arg.TypeAttr] = values.dtype; |
|
|
|
if (!input_arg.IsRef && dtype != DataType.DtInvalid) |
|
|
|
dtype = dtype.as_base_dtype(); |
|
|
|
|
|
|
|
values = new Tensor[] { value }; |
|
|
|
} |
|
|
|
values = ops.internal_convert_n_to_tensor(values as object[], |
|
|
|
name: input_arg.Name, |
|
|
|
dtype: dtype.as_tf_dtype(), |
|
|
|
preferred_dtype: default_dtype.as_tf_dtype(), |
|
|
|
as_ref: input_arg.IsRef); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
if (input_arg.Type != DataType.DtInvalid) |
|
|
|
dtype = input_arg.Type; |
|
|
|
else if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
|
dtype = (DataType)attrs[input_arg.TypeAttr]; |
|
|
|
else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid) |
|
|
|
dtype = DataType.DtString; |
|
|
|
else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
|
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; |
|
|
|
|
|
|
|
var value = ops.convert_to_tensor(values, |
|
|
|
name: input_name, |
|
|
|
dtype: dtype.as_tf_dtype(), |
|
|
|
as_ref: input_arg.IsRef, |
|
|
|
preferred_dtype: default_dtype.as_tf_dtype()); |
|
|
|
|
|
|
|
//if (!String.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
//attrs[input_arg.TypeAttr] = values.dtype; |
|
|
|
|
|
|
|
values = new Tensor[] { value }; |
|
|
|
} |
|
|
|
|
|
|
|
if (values is Tensor[] values2) |
|
|
|
{ |
|
|
|
types = values2.Select(x => x.dtype).ToList(); |
|
|
|
inputs.AddRange(values2); |
|
|
|
base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList(); |
|
|
|
} |
|
|
|
else throw new NotImplementedException("_IsListParameter"); |
|
|
|
|
|
|
|
SetAttrs(op_type_name, |
|
|
|
input_arg, |
|
|
|
op_def, |
|
|
|
attrs, |
|
|
|
inferred_from, |
|
|
|
types, |
|
|
|
base_types, |
|
|
|
input_types, |
|
|
|
values); |
|
|
|
if (values is Tensor[] values2) |
|
|
|
{ |
|
|
|
types = values2.Select(x => x.dtype).ToList(); |
|
|
|
inputs.AddRange(values2); |
|
|
|
base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList(); |
|
|
|
} |
|
|
|
else throw new NotImplementedException("_IsListParameter"); |
|
|
|
|
|
|
|
SetAttrs(op_type_name, |
|
|
|
input_arg, |
|
|
|
op_def, |
|
|
|
attrs, |
|
|
|
inferred_from, |
|
|
|
types, |
|
|
|
base_types, |
|
|
|
input_types, |
|
|
|
values); |
|
|
|
} |
|
|
|
|
|
|
|
// Process remaining attrs |
|
|
|
foreach (var attr in op_def.Attr) |
|
|
|
// Process remaining attrs |
|
|
|
foreach (var attr in op_def.Attr) |
|
|
|
{ |
|
|
|
if (keywords.ContainsKey(attr.Name)) |
|
|
|
{ |
|
|
|
if (keywords.ContainsKey(attr.Name)) |
|
|
|
{ |
|
|
|
attrs[attr.Name] = keywords[attr.Name]; |
|
|
|
} |
|
|
|
attrs[attr.Name] = keywords[attr.Name]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Convert attr values to AttrValue protos. |
|
|
|
var attr_protos = new Dictionary<string, AttrValue>(); |
|
|
|
foreach (AttrDef attr_def in op_def.Attr) |
|
|
|
// Convert attr values to AttrValue protos. |
|
|
|
var attr_protos = new Dictionary<string, AttrValue>(); |
|
|
|
foreach (AttrDef attr_def in op_def.Attr) |
|
|
|
{ |
|
|
|
var key = attr_def.Name; |
|
|
|
if (attrs.ContainsKey(key)) |
|
|
|
{ |
|
|
|
var key = attr_def.Name; |
|
|
|
if (attrs.ContainsKey(key)) |
|
|
|
{ |
|
|
|
attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); |
|
|
|
} |
|
|
|
else |
|
|
|
attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
if (attr_def.DefaultValue == null) |
|
|
|
{ |
|
|
|
if (attr_def.DefaultValue == null) |
|
|
|
{ |
|
|
|
throw new TypeError("Missing required positional argument " + key); |
|
|
|
} |
|
|
|
throw new TypeError("Missing required positional argument " + key); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
attrs.Clear(); |
|
|
|
attrs.Clear(); |
|
|
|
|
|
|
|
// Determine output types (possibly using attrs) |
|
|
|
var output_types = new List<TF_DataType>(); |
|
|
|
// Determine output types (possibly using attrs) |
|
|
|
var output_types = new List<TF_DataType>(); |
|
|
|
|
|
|
|
foreach (var arg in op_def.OutputArg) |
|
|
|
foreach (var arg in op_def.OutputArg) |
|
|
|
{ |
|
|
|
types = new List<TF_DataType>(); |
|
|
|
if (!string.IsNullOrEmpty(arg.NumberAttr)) |
|
|
|
{ |
|
|
|
types = new List<TF_DataType>(); |
|
|
|
if (!string.IsNullOrEmpty(arg.NumberAttr)) |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
else if (!string.IsNullOrEmpty(arg.TypeAttr)) |
|
|
|
{ |
|
|
|
types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; |
|
|
|
} |
|
|
|
} |
|
|
|
else if (!string.IsNullOrEmpty(arg.TypeAttr)) |
|
|
|
{ |
|
|
|
types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; |
|
|
|
} |
|
|
|
|
|
|
|
if (arg.IsRef) |
|
|
|
types = types.Select(x => x.as_ref()).ToList(); |
|
|
|
if (arg.IsRef) |
|
|
|
types = types.Select(x => x.as_ref()).ToList(); |
|
|
|
|
|
|
|
output_types.AddRange(types); |
|
|
|
} |
|
|
|
output_types.AddRange(types); |
|
|
|
} |
|
|
|
|
|
|
|
// We add an explicit colocation constraint between |
|
|
|
// the newly created op and any of its reference-typed inputs. |
|
|
|
var must_colocate_inputs = zip(op_def.InputArg, inputs) |
|
|
|
.Where(x => x.Item1.IsRef) |
|
|
|
.Select(x => x.Item2) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
_MaybeColocateWith(must_colocate_inputs); |
|
|
|
|
|
|
|
// Add Op to graph |
|
|
|
var ret_op = g.create_op(op_type_name, |
|
|
|
inputs.ToArray(), |
|
|
|
output_types.ToArray(), |
|
|
|
name: _scope_name, |
|
|
|
input_types: input_types.ToArray(), |
|
|
|
attrs: attr_protos, |
|
|
|
op_def: op_def); |
|
|
|
|
|
|
|
scope.__exit__(); |
|
|
|
|
|
|
|
// We add an explicit colocation constraint between |
|
|
|
// the newly created op and any of its reference-typed inputs. |
|
|
|
var must_colocate_inputs = zip(op_def.InputArg, inputs) |
|
|
|
.Where(x => x.Item1.IsRef) |
|
|
|
.Select(x => x.Item2) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
_MaybeColocateWith(must_colocate_inputs); |
|
|
|
|
|
|
|
// Add Op to graph |
|
|
|
var op = g.create_op(op_type_name, |
|
|
|
inputs.ToArray(), |
|
|
|
output_types.ToArray(), |
|
|
|
name: _scope_name, |
|
|
|
input_types: input_types.ToArray(), |
|
|
|
attrs: attr_protos, |
|
|
|
op_def: op_def); |
|
|
|
|
|
|
|
return op; |
|
|
|
}); |
|
|
|
g.Exit(); |
|
|
|
|
|
|
|
return ret_op; |
|
|
|
} |
|
|
|
|
|
|
|