From afaf0c80729abd0f01dba42ecaad800c8a08985a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 10 Mar 2020 23:16:35 -0500 Subject: [PATCH] fix input list into op --- .../Eager/pywrap_tfe_src.cs | 49 ++++++++++++++----- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs index bf2ecb62..792e945d 100644 --- a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs +++ b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs @@ -10,6 +10,7 @@ namespace Tensorflow.Eager /// public class pywrap_tfe_src { + static int kFastPathExecuteInputStartIndex = 0; public static EagerTensor TFE_Py_FastPathExecute(Context ctx, string device_name, string opName, @@ -28,7 +29,7 @@ namespace Tensorflow.Eager // Set non-inferred attrs, including setting defaults if the attr is passed in // as None. - for (int i = op_def.InputArg.Count; i < args_size; i += 2) + for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2) { var attr_name = args[i].ToString(); var attr_value = args[i + 1]; @@ -38,20 +39,39 @@ namespace Tensorflow.Eager if(attr_name == attr.Name) { SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status); + status.Check(true); break; } } } c_api.TFE_OpSetDevice(op, device_name, status); + status.Check(true); + // Add inferred attrs and inputs. for (int i = 0; i < op_def.InputArg.Count; i++) { var input_arg = op_def.InputArg[i]; + int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; if (!string.IsNullOrEmpty(input_arg.NumberAttr)) { - c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0); - attr_list_sizes[input_arg.NumberAttr] = 0; + c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); + attr_list_sizes[input_arg.NumberAttr] = len; + + if (len > 0) + { + var fast_input_array = (object[])args[i]; + // First item adds the type attr. + if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status)) + return null; + + for (var j = 1; j < len; j++) + { + // Since the list is homogeneous, we don't need to re-add the attr. + if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status)) + return null; + } + } } else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) { @@ -60,14 +80,7 @@ namespace Tensorflow.Eager else { // The item is a single item. - switch (args[i]) - { - case Tensor inputTensor: - AddInputToOp(inputTensor, true, input_arg, op, status); - break; - default: - throw new NotImplementedException(""); - } + AddInputToOp(args[i], true, input_arg, op, status); } } @@ -106,13 +119,23 @@ namespace Tensorflow.Eager /// /// /// - private static bool AddInputToOp(Tensor input, + private static bool AddInputToOp(object inputs, bool add_type_attr, ArgDef input_arg, IntPtr op, Status status) { - var input_handle = c_api.TFE_NewTensorHandle(input, status); + IntPtr input_handle = IntPtr.Zero; + + switch (inputs) + { + case Tensor input: + input_handle = c_api.TFE_NewTensorHandle(input, status); + break; + default: + throw new NotImplementedException(""); + } + if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) {