|
|
@@ -10,6 +10,7 @@ namespace Tensorflow.Eager |
|
|
|
/// </summary> |
|
|
|
public class pywrap_tfe_src |
|
|
|
{ |
|
|
|
static int kFastPathExecuteInputStartIndex = 0; |
|
|
|
public static EagerTensor TFE_Py_FastPathExecute(Context ctx, |
|
|
|
string device_name, |
|
|
|
string opName, |
|
|
@@ -28,7 +29,7 @@ namespace Tensorflow.Eager |
|
|
|
|
|
|
|
// Set non-inferred attrs, including setting defaults if the attr is passed in |
|
|
|
// as None. |
|
|
|
for (int i = op_def.InputArg.Count; i < args_size; i += 2) |
|
|
|
for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2) |
|
|
|
{ |
|
|
|
var attr_name = args[i].ToString(); |
|
|
|
var attr_value = args[i + 1]; |
|
|
@@ -38,20 +39,39 @@ namespace Tensorflow.Eager |
|
|
|
if(attr_name == attr.Name) |
|
|
|
{ |
|
|
|
SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status); |
|
|
|
status.Check(true); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
c_api.TFE_OpSetDevice(op, device_name, status); |
|
|
|
status.Check(true); |
|
|
|
|
|
|
|
// Add inferred attrs and inputs. |
|
|
|
for (int i = 0; i < op_def.InputArg.Count; i++) |
|
|
|
{ |
|
|
|
var input_arg = op_def.InputArg[i]; |
|
|
|
int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; |
|
|
|
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) |
|
|
|
{ |
|
|
|
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0); |
|
|
|
attr_list_sizes[input_arg.NumberAttr] = 0; |
|
|
|
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); |
|
|
|
attr_list_sizes[input_arg.NumberAttr] = len; |
|
|
|
|
|
|
|
if (len > 0) |
|
|
|
{ |
|
|
|
var fast_input_array = (object[])args[i]; |
|
|
|
// First item adds the type attr. |
|
|
|
if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status)) |
|
|
|
return null; |
|
|
|
|
|
|
|
for (var j = 1; j < len; j++) |
|
|
|
{ |
|
|
|
// Since the list is homogeneous, we don't need to re-add the attr. |
|
|
|
if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status)) |
|
|
|
return null; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) |
|
|
|
{ |
|
|
@@ -60,14 +80,7 @@ namespace Tensorflow.Eager |
|
|
|
else |
|
|
|
{ |
|
|
|
// The item is a single item. |
|
|
|
switch (args[i]) |
|
|
|
{ |
|
|
|
case Tensor inputTensor: |
|
|
|
AddInputToOp(inputTensor, true, input_arg, op, status); |
|
|
|
break; |
|
|
|
default: |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
AddInputToOp(args[i], true, input_arg, op, status); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -106,13 +119,23 @@ namespace Tensorflow.Eager |
|
|
|
/// <param name="op"></param> |
|
|
|
/// <param name="status"></param> |
|
|
|
/// <returns></returns> |
|
|
|
private static bool AddInputToOp(Tensor input, |
|
|
|
private static bool AddInputToOp(object inputs, |
|
|
|
bool add_type_attr, |
|
|
|
ArgDef input_arg, |
|
|
|
IntPtr op, |
|
|
|
Status status) |
|
|
|
{ |
|
|
|
var input_handle = c_api.TFE_NewTensorHandle(input, status); |
|
|
|
IntPtr input_handle = IntPtr.Zero; |
|
|
|
|
|
|
|
switch (inputs) |
|
|
|
{ |
|
|
|
case Tensor input: |
|
|
|
input_handle = c_api.TFE_NewTensorHandle(input, status); |
|
|
|
break; |
|
|
|
default: |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) |
|
|
|
{ |
|
|
|