|
|
@@ -24,7 +24,7 @@ namespace Tensorflow |
|
|
|
_ops[op_def.Name] = op_def; |
|
|
|
} |
|
|
|
|
|
|
|
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", DataType? dtype = null, TensorShape shape = null) |
|
|
|
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) |
|
|
|
{ |
|
|
|
var op_def = _ops[op_type_name]; |
|
|
|
|
|
|
@@ -46,9 +46,30 @@ namespace Tensorflow |
|
|
|
var key = attr_def.Name; |
|
|
|
} |
|
|
|
|
|
|
|
foreach(var input_arg in op_def.InputArg) |
|
|
|
var attrs = new Dictionary<string, object>(); |
|
|
|
var inputs = new List<Tensor>(); |
|
|
|
var input_types = new List<DataType>(); |
|
|
|
|
|
|
|
foreach (var attr in op_def.Attr) |
|
|
|
{ |
|
|
|
if (keywords.ContainsKey(attr.Name)) |
|
|
|
{ |
|
|
|
attrs[attr.Name] = keywords[attr.Name]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
foreach (var input_arg in op_def.InputArg) |
|
|
|
{ |
|
|
|
var input_name = input_arg.Name; |
|
|
|
if (keywords.ContainsKey(input_name)) |
|
|
|
{ |
|
|
|
inputs.Add(keywords[input_name] as Tensor); |
|
|
|
} |
|
|
|
|
|
|
|
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
{ |
|
|
|
attrs[input_arg.TypeAttr] = DataType.DtFloat; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
var attr_protos = new Dictionary<string, AttrValue>(); |
|
|
@@ -60,7 +81,7 @@ namespace Tensorflow |
|
|
|
switch (attr_def.Type) |
|
|
|
{ |
|
|
|
case "type": |
|
|
|
attr_value.Type = dtype.Value; |
|
|
|
attr_value.Type = (DataType)keywords["dtype"]; |
|
|
|
break; |
|
|
|
case "shape": |
|
|
|
attr_value.Shape = new TensorShapeProto(); |
|
|
@@ -84,9 +105,9 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
var op = g.create_op(op_type_name, null, output_types.ToArray(), |
|
|
|
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), |
|
|
|
name: scope, |
|
|
|
input_types: new DataType[] { }, |
|
|
|
input_types: input_types.ToArray(), |
|
|
|
attrs: attr_protos, |
|
|
|
op_def: op_def); |
|
|
|
|
|
|
|