Browse Source

fix input list into op

tags/v0.20
Oceania2018 5 years ago
parent
commit
afaf0c8072
1 changed files with 36 additions and 13 deletions
  1. +36
    -13
      src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs

+ 36
- 13
src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs View File

@@ -10,6 +10,7 @@ namespace Tensorflow.Eager
/// </summary>
public class pywrap_tfe_src
{
static int kFastPathExecuteInputStartIndex = 0;
public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
string device_name,
string opName,
@@ -28,7 +29,7 @@ namespace Tensorflow.Eager

// Set non-inferred attrs, including setting defaults if the attr is passed in
// as None.
for (int i = op_def.InputArg.Count; i < args_size; i += 2)
for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
{
var attr_name = args[i].ToString();
var attr_value = args[i + 1];
@@ -38,20 +39,39 @@ namespace Tensorflow.Eager
if(attr_name == attr.Name)
{
SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
status.Check(true);
break;
}
}
}

c_api.TFE_OpSetDevice(op, device_name, status);
status.Check(true);

// Add inferred attrs and inputs.
for (int i = 0; i < op_def.InputArg.Count; i++)
{
var input_arg = op_def.InputArg[i];
int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length;
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0);
attr_list_sizes[input_arg.NumberAttr] = 0;
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
attr_list_sizes[input_arg.NumberAttr] = len;

if (len > 0)
{
var fast_input_array = (object[])args[i];
// First item adds the type attr.
if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status))
return null;

for (var j = 1; j < len; j++)
{
// Since the list is homogeneous, we don't need to re-add the attr.
if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status))
return null;
}
}
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{
@@ -60,14 +80,7 @@ namespace Tensorflow.Eager
else
{
// The item is a single item.
switch (args[i])
{
case Tensor inputTensor:
AddInputToOp(inputTensor, true, input_arg, op, status);
break;
default:
throw new NotImplementedException("");
}
AddInputToOp(args[i], true, input_arg, op, status);
}
}

@@ -106,13 +119,23 @@ namespace Tensorflow.Eager
/// <param name="op"></param>
/// <param name="status"></param>
/// <returns></returns>
private static bool AddInputToOp(Tensor input,
private static bool AddInputToOp(object inputs,
bool add_type_attr,
ArgDef input_arg,
IntPtr op,
Status status)
{
var input_handle = c_api.TFE_NewTensorHandle(input, status);
IntPtr input_handle = IntPtr.Zero;

switch (inputs)
{
case Tensor input:
input_handle = c_api.TFE_NewTensorHandle(input, status);
break;
default:
throw new NotImplementedException("");
}


if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
{


Loading…
Cancel
Save