Browse Source

fix input list into op

tags/v0.20
Oceania2018 5 years ago
parent
commit
afaf0c8072
1 changed files with 36 additions and 13 deletions
  1. +36
    -13
      src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs

+ 36
- 13
src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs View File

@@ -10,6 +10,7 @@ namespace Tensorflow.Eager
/// </summary> /// </summary>
public class pywrap_tfe_src public class pywrap_tfe_src
{ {
static int kFastPathExecuteInputStartIndex = 0;
public static EagerTensor TFE_Py_FastPathExecute(Context ctx, public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
string device_name, string device_name,
string opName, string opName,
@@ -28,7 +29,7 @@ namespace Tensorflow.Eager


// Set non-inferred attrs, including setting defaults if the attr is passed in // Set non-inferred attrs, including setting defaults if the attr is passed in
// as None. // as None.
for (int i = op_def.InputArg.Count; i < args_size; i += 2)
for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
{ {
var attr_name = args[i].ToString(); var attr_name = args[i].ToString();
var attr_value = args[i + 1]; var attr_value = args[i + 1];
@@ -38,20 +39,39 @@ namespace Tensorflow.Eager
if(attr_name == attr.Name) if(attr_name == attr.Name)
{ {
SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status); SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
status.Check(true);
break; break;
} }
} }
} }


c_api.TFE_OpSetDevice(op, device_name, status); c_api.TFE_OpSetDevice(op, device_name, status);
status.Check(true);


// Add inferred attrs and inputs.
for (int i = 0; i < op_def.InputArg.Count; i++) for (int i = 0; i < op_def.InputArg.Count; i++)
{ {
var input_arg = op_def.InputArg[i]; var input_arg = op_def.InputArg[i];
int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length;
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{ {
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0);
attr_list_sizes[input_arg.NumberAttr] = 0;
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
attr_list_sizes[input_arg.NumberAttr] = len;

if (len > 0)
{
var fast_input_array = (object[])args[i];
// First item adds the type attr.
if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status))
return null;

for (var j = 1; j < len; j++)
{
// Since the list is homogeneous, we don't need to re-add the attr.
if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status))
return null;
}
}
} }
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{ {
@@ -60,14 +80,7 @@ namespace Tensorflow.Eager
else else
{ {
// The item is a single item. // The item is a single item.
switch (args[i])
{
case Tensor inputTensor:
AddInputToOp(inputTensor, true, input_arg, op, status);
break;
default:
throw new NotImplementedException("");
}
AddInputToOp(args[i], true, input_arg, op, status);
} }
} }


@@ -106,13 +119,23 @@ namespace Tensorflow.Eager
/// <param name="op"></param> /// <param name="op"></param>
/// <param name="status"></param> /// <param name="status"></param>
/// <returns></returns> /// <returns></returns>
private static bool AddInputToOp(Tensor input,
private static bool AddInputToOp(object inputs,
bool add_type_attr, bool add_type_attr,
ArgDef input_arg, ArgDef input_arg,
IntPtr op, IntPtr op,
Status status) Status status)
{ {
var input_handle = c_api.TFE_NewTensorHandle(input, status);
IntPtr input_handle = IntPtr.Zero;

switch (inputs)
{
case Tensor input:
input_handle = c_api.TFE_NewTensorHandle(input, status);
break;
default:
throw new NotImplementedException("");
}



if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
{ {


Loading…
Cancel
Save