diff --git a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
index bf2ecb62..792e945d 100644
--- a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
+++ b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
@@ -10,6 +10,7 @@ namespace Tensorflow.Eager
///
public class pywrap_tfe_src
{
+ static int kFastPathExecuteInputStartIndex = 0;
public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
string device_name,
string opName,
@@ -28,7 +29,7 @@ namespace Tensorflow.Eager
// Set non-inferred attrs, including setting defaults if the attr is passed in
// as None.
- for (int i = op_def.InputArg.Count; i < args_size; i += 2)
+ for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
{
var attr_name = args[i].ToString();
var attr_value = args[i + 1];
@@ -38,20 +39,39 @@ namespace Tensorflow.Eager
if(attr_name == attr.Name)
{
SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
+ status.Check(true);
break;
}
}
}
c_api.TFE_OpSetDevice(op, device_name, status);
+ status.Check(true);
+ // Add inferred attrs and inputs.
for (int i = 0; i < op_def.InputArg.Count; i++)
{
var input_arg = op_def.InputArg[i];
+ int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length;
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
- c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0);
- attr_list_sizes[input_arg.NumberAttr] = 0;
+ c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
+ attr_list_sizes[input_arg.NumberAttr] = len;
+
+ if (len > 0)
+ {
+ var fast_input_array = (object[])args[i];
+ // First item adds the type attr.
+ if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status))
+ return null;
+
+ for (var j = 1; j < len; j++)
+ {
+ // Since the list is homogeneous, we don't need to re-add the attr.
+ if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status))
+ return null;
+ }
+ }
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{
@@ -60,14 +80,7 @@ namespace Tensorflow.Eager
else
{
// The item is a single item.
- switch (args[i])
- {
- case Tensor inputTensor:
- AddInputToOp(inputTensor, true, input_arg, op, status);
- break;
- default:
- throw new NotImplementedException("");
- }
+ AddInputToOp(args[i], true, input_arg, op, status);
}
}
@@ -106,13 +119,23 @@ namespace Tensorflow.Eager
///
///
///
- private static bool AddInputToOp(Tensor input,
+ private static bool AddInputToOp(object inputs,
bool add_type_attr,
ArgDef input_arg,
IntPtr op,
Status status)
{
- var input_handle = c_api.TFE_NewTensorHandle(input, status);
+ IntPtr input_handle = IntPtr.Zero;
+
+ switch (inputs)
+ {
+ case Tensor input:
+ input_handle = c_api.TFE_NewTensorHandle(input, status);
+ break;
+ default:
+ throw new NotImplementedException("");
+ }
+
if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
{