Browse Source

fix internal_convert_n_to_tensor return type.

tags/v0.8.0
haiping008 6 years ago
parent
commit
1c5731faf5
2 changed files with 14 additions and 21 deletions
  1. +9
    -16
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  2. +5
    -5
      src/TensorFlowNET.Core/ops.py.cs

+ 9
- 16
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -84,32 +84,25 @@ namespace Tensorflow
dtype = dtype.as_base_dtype();

values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef);
inputs.AddRange(values as Tensor[]);
}
else
{
if (!(values is Tensor))
if (keywords[input_name] is Tensor)
{
keywords[input_name] = constant_op.constant(values, input_name);
}

if (keywords[input_name] is Tensor value)
else
{
if (keywords.ContainsKey(input_name))
{
inputs.Add(value);
}

if (!String.IsNullOrEmpty(input_arg.TypeAttr))
{
attrs[input_arg.TypeAttr] = value.dtype;
}
keywords[input_name] = ops.internal_convert_to_tensor(values, name: input_name);
}

values = new Tensor[] { value };
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
{
attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype;
}
values = new Tensor[] { keywords[input_name] as Tensor };
}

inputs.AddRange(values as Tensor[]);
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
input_types.AddRange(base_types);
}


+ 5
- 5
src/TensorFlowNET.Core/ops.py.cs View File

@@ -310,11 +310,11 @@ namespace Tensorflow
};
}

public static T[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
string name = "", DataType preferred_dtype = DataType.DtInvalid,
bool as_ref = false)
{
var ret = new List<T>();
var ret = new List<Tensor>();

foreach((int i, T value) in Python.enumerate(values))
{
@@ -325,16 +325,16 @@ namespace Tensorflow
return ret.ToArray();
}

public static T internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
string name = "", DataType preferred_dtype = DataType.DtInvalid,
bool as_ref = false)
{
switch (typeof(T).Name)
{
case "Tensor":
return value;
return value as Tensor;
default:
throw new NotImplementedException("internal_convert_to_tensor");
return constant_op.constant(np.array(value), name);
}
}
}


Loading…
Cancel
Save