|
@@ -40,13 +40,15 @@ namespace Tensorflow |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
var attrs = new Dictionary<string, object>(); |
|
|
var attrs = new Dictionary<string, object>(); |
|
|
var inferred_from = new Dictionary<string, object>(); |
|
|
|
|
|
var inputs = new List<Tensor>(); |
|
|
var inputs = new List<Tensor>(); |
|
|
var input_types = new List<TF_DataType>(); |
|
|
var input_types = new List<TF_DataType>(); |
|
|
var base_types = new List<TF_DataType>(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope => |
|
|
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope => |
|
|
{ |
|
|
{ |
|
|
|
|
|
var inferred_from = new Dictionary<string, object>(); |
|
|
|
|
|
var base_types = new List<TF_DataType>(); |
|
|
|
|
|
var types = new List<TF_DataType>(); |
|
|
|
|
|
|
|
|
// Perform input type inference |
|
|
// Perform input type inference |
|
|
foreach (var input_arg in op_def.InputArg) |
|
|
foreach (var input_arg in op_def.InputArg) |
|
|
{ |
|
|
{ |
|
@@ -72,20 +74,14 @@ namespace Tensorflow |
|
|
if (!_IsListValue(values)) |
|
|
if (!_IsListValue(values)) |
|
|
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}."); |
|
|
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}."); |
|
|
if(input_arg.Type != DataType.DtInvalid) |
|
|
if(input_arg.Type != DataType.DtInvalid) |
|
|
{ |
|
|
|
|
|
dtype = input_arg.Type; |
|
|
dtype = input_arg.Type; |
|
|
} |
|
|
|
|
|
else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) |
|
|
else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) |
|
|
{ |
|
|
{ |
|
|
if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
{ |
|
|
|
|
|
dtype = (DataType)attrs[input_arg.TypeAttr]; |
|
|
dtype = (DataType)attrs[input_arg.TypeAttr]; |
|
|
} |
|
|
|
|
|
else |
|
|
else |
|
|
{ |
|
|
|
|
|
if (values is Tensor[] values1) |
|
|
if (values is Tensor[] values1) |
|
|
dtype = values1[0].dtype.as_datatype_enum(); |
|
|
dtype = values1[0].dtype.as_datatype_enum(); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; |
|
|
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; |
|
@@ -94,86 +90,48 @@ namespace Tensorflow |
|
|
if(input_arg.IsRef && dtype != DataType.DtInvalid) |
|
|
if(input_arg.IsRef && dtype != DataType.DtInvalid) |
|
|
dtype = dtype.as_base_dtype(); |
|
|
dtype = dtype.as_base_dtype(); |
|
|
|
|
|
|
|
|
values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef); |
|
|
|
|
|
|
|
|
values = ops.internal_convert_n_to_tensor(values, |
|
|
|
|
|
name: input_arg.Name, |
|
|
|
|
|
dtype: dtype, |
|
|
|
|
|
preferred_dtype: default_dtype, |
|
|
|
|
|
as_ref: input_arg.IsRef); |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
|
|
|
|
|
|
if (input_arg.Type != DataType.DtInvalid) |
|
|
|
|
|
dtype = input_arg.Type; |
|
|
|
|
|
else if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
|
|
|
dtype = (DataType)attrs[input_arg.TypeAttr]; |
|
|
|
|
|
else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) |
|
|
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; |
|
|
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; |
|
|
|
|
|
|
|
|
if (keywords[input_name] is Tensor) |
|
|
|
|
|
{ |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
keywords[input_name] = ops.internal_convert_to_tensor(values, name: input_name, as_ref: input_arg.IsRef); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype; |
|
|
|
|
|
} |
|
|
|
|
|
values = new Tensor[] { keywords[input_name] as Tensor }; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
inputs.AddRange(values as Tensor[]); |
|
|
|
|
|
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); |
|
|
|
|
|
input_types.AddRange(base_types); |
|
|
|
|
|
|
|
|
|
|
|
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
if (attrs.ContainsKey(input_arg.NumberAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.NumberAttr] = (values as Tensor[]).Length; |
|
|
|
|
|
inferred_from[input_arg.NumberAttr] = input_name; |
|
|
|
|
|
var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr); |
|
|
|
|
|
if (num_attr.HasMinimum && (values as Tensor[]).Length < num_attr.Minimum) |
|
|
|
|
|
throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {(values as Tensor[]).Length} shorter " + |
|
|
|
|
|
$"than minimum length {num_attr.Minimum}"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
values = ops.internal_convert_to_tensor(values, |
|
|
|
|
|
name: input_name, |
|
|
|
|
|
as_ref: input_arg.IsRef); |
|
|
|
|
|
|
|
|
// All tensors must have the same base type. |
|
|
|
|
|
if(input_arg.Type != DataType.DtInvalid) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
//if (!String.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
|
|
//attrs[input_arg.TypeAttr] = values.dtype; |
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.TypeAttr] = base_types[0]; |
|
|
|
|
|
inferred_from[input_arg.TypeAttr] = input_name; |
|
|
|
|
|
var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
values = new Tensor[] { values }; |
|
|
} |
|
|
} |
|
|
else if (!string.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
var attr_value = base_types[0]; |
|
|
|
|
|
if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.TypeAttr] = attr_value; |
|
|
|
|
|
inferred_from[input_arg.TypeAttr] = input_name; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) |
|
|
|
|
|
|
|
|
if (values is Tensor[] values2) |
|
|
{ |
|
|
{ |
|
|
var attr_value = base_types; |
|
|
|
|
|
if (attrs.ContainsKey(input_arg.TypeListAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.TypeListAttr] = attr_value; |
|
|
|
|
|
inferred_from[input_arg.TypeListAttr] = input_name; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
types = values2.Select(x => x.dtype).ToList(); |
|
|
|
|
|
inputs.AddRange(values2); |
|
|
|
|
|
base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList(); |
|
|
} |
|
|
} |
|
|
|
|
|
else throw new NotImplementedException("_IsListParameter"); |
|
|
|
|
|
|
|
|
|
|
|
SetAttrs(op_type_name, |
|
|
|
|
|
input_arg, |
|
|
|
|
|
op_def, |
|
|
|
|
|
attrs, |
|
|
|
|
|
inferred_from, |
|
|
|
|
|
types, |
|
|
|
|
|
base_types, |
|
|
|
|
|
input_types, |
|
|
|
|
|
values); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Process remaining attrs |
|
|
// Process remaining attrs |
|
@@ -190,22 +148,26 @@ namespace Tensorflow |
|
|
foreach (var attr_def in op_def.Attr) |
|
|
foreach (var attr_def in op_def.Attr) |
|
|
{ |
|
|
{ |
|
|
var key = attr_def.Name; |
|
|
var key = attr_def.Name; |
|
|
|
|
|
var value = attrs[key]; |
|
|
|
|
|
|
|
|
if (!attrs.ContainsKey(key)) |
|
|
if (!attrs.ContainsKey(key)) |
|
|
Console.WriteLine($"_apply_op_helper: key '{key}' is not found in '{op_def.Name}' operation's attr_def."); |
|
|
Console.WriteLine($"_apply_op_helper: key '{key}' is not found in '{op_def.Name}' operation's attr_def."); |
|
|
|
|
|
|
|
|
attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); |
|
|
|
|
|
|
|
|
attr_protos[key] = SetAttrValue(op_def, attr_def, value); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
attrs.Clear(); |
|
|
|
|
|
|
|
|
// Determine output types (possibly using attrs) |
|
|
// Determine output types (possibly using attrs) |
|
|
var output_types = new List<TF_DataType>(); |
|
|
var output_types = new List<TF_DataType>(); |
|
|
|
|
|
|
|
|
foreach (var arg in op_def.OutputArg) |
|
|
foreach (var arg in op_def.OutputArg) |
|
|
{ |
|
|
{ |
|
|
if (!String.IsNullOrEmpty(arg.NumberAttr)) |
|
|
|
|
|
|
|
|
if (!string.IsNullOrEmpty(arg.NumberAttr)) |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
else if (!String.IsNullOrEmpty(arg.TypeAttr)) |
|
|
|
|
|
|
|
|
else if (!string.IsNullOrEmpty(arg.TypeAttr)) |
|
|
{ |
|
|
{ |
|
|
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); |
|
|
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); |
|
|
} |
|
|
} |
|
@@ -222,6 +184,79 @@ namespace Tensorflow |
|
|
}); |
|
|
}); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private void SetAttrs(string op_type_name, |
|
|
|
|
|
ArgDef input_arg, |
|
|
|
|
|
OpDef op_def, |
|
|
|
|
|
Dictionary<string, object> attrs, |
|
|
|
|
|
Dictionary<string, object> inferred_from, |
|
|
|
|
|
List<TF_DataType> types, |
|
|
|
|
|
List<TF_DataType> base_types, |
|
|
|
|
|
List<TF_DataType> input_types, |
|
|
|
|
|
dynamic values) |
|
|
|
|
|
{ |
|
|
|
|
|
var input_name = input_arg.Name; |
|
|
|
|
|
|
|
|
|
|
|
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
if (attrs.ContainsKey(input_arg.NumberAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.NumberAttr] = (values as Tensor[]).Length; |
|
|
|
|
|
inferred_from[input_arg.NumberAttr] = input_name; |
|
|
|
|
|
var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr); |
|
|
|
|
|
if (num_attr.HasMinimum && (values as Tensor[]).Length < num_attr.Minimum) |
|
|
|
|
|
throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {(values as Tensor[]).Length} shorter " + |
|
|
|
|
|
$"than minimum length {num_attr.Minimum}"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// All tensors must have the same base type. |
|
|
|
|
|
if (input_arg.Type != DataType.DtInvalid) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.TypeAttr] = base_types[0]; |
|
|
|
|
|
inferred_from[input_arg.TypeAttr] = input_name; |
|
|
|
|
|
var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
else if (!string.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
var attr_value = base_types[0]; |
|
|
|
|
|
if (attrs.ContainsKey(input_arg.TypeAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.TypeAttr] = attr_value; |
|
|
|
|
|
inferred_from[input_arg.TypeAttr] = input_name; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
var attr_value = base_types; |
|
|
|
|
|
if (attrs.ContainsKey(input_arg.TypeListAttr)) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
attrs[input_arg.TypeListAttr] = attr_value; |
|
|
|
|
|
inferred_from[input_arg.TypeListAttr] = input_name; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (input_arg.IsRef) |
|
|
|
|
|
input_types.AddRange(types); |
|
|
|
|
|
else |
|
|
|
|
|
input_types.AddRange(base_types); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
public DataType _MakeType(TF_DataType v, AttrDef attr_def) |
|
|
public DataType _MakeType(TF_DataType v, AttrDef attr_def) |
|
|
{ |
|
|
{ |
|
|
return v.as_base_dtype().as_datatype_enum(); |
|
|
return v.as_base_dtype().as_datatype_enum(); |
|
@@ -231,6 +266,13 @@ namespace Tensorflow |
|
|
{ |
|
|
{ |
|
|
var attr_value = new AttrValue(); |
|
|
var attr_value = new AttrValue(); |
|
|
|
|
|
|
|
|
|
|
|
if (attr_def.Type.StartsWith("list(")) |
|
|
|
|
|
{ |
|
|
|
|
|
if (attr_def.HasMinimum) |
|
|
|
|
|
; |
|
|
|
|
|
attr_value.List = new AttrValue.Types.ListValue(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
switch (attr_def.Type) |
|
|
switch (attr_def.Type) |
|
|
{ |
|
|
{ |
|
|
case "string": |
|
|
case "string": |
|
@@ -240,8 +282,6 @@ namespace Tensorflow |
|
|
attr_value.Type = _MakeType((TF_DataType)value, attr_def); |
|
|
attr_value.Type = _MakeType((TF_DataType)value, attr_def); |
|
|
break; |
|
|
break; |
|
|
case "list(type)": |
|
|
case "list(type)": |
|
|
if (attr_value.List == null) |
|
|
|
|
|
attr_value.List = new AttrValue.Types.ListValue(); |
|
|
|
|
|
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); |
|
|
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); |
|
|
break; |
|
|
break; |
|
|
case "bool": |
|
|
case "bool": |
|
|