@@ -84,32 +84,25 @@ namespace Tensorflow | |||
dtype = dtype.as_base_dtype(); | |||
values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef); | |||
inputs.AddRange(values as Tensor[]); | |||
} | |||
else | |||
{ | |||
if (!(values is Tensor)) | |||
if (keywords[input_name] is Tensor) | |||
{ | |||
keywords[input_name] = constant_op.constant(values, input_name); | |||
} | |||
if (keywords[input_name] is Tensor value) | |||
else | |||
{ | |||
if (keywords.ContainsKey(input_name)) | |||
{ | |||
inputs.Add(value); | |||
} | |||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
{ | |||
attrs[input_arg.TypeAttr] = value.dtype; | |||
} | |||
keywords[input_name] = ops.internal_convert_to_tensor(values, name: input_name); | |||
} | |||
values = new Tensor[] { value }; | |||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
{ | |||
attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype; | |||
} | |||
values = new Tensor[] { keywords[input_name] as Tensor }; | |||
} | |||
inputs.AddRange(values as Tensor[]); | |||
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); | |||
input_types.AddRange(base_types); | |||
} | |||
@@ -310,11 +310,11 @@ namespace Tensorflow | |||
}; | |||
} | |||
public static T[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid, | |||
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid, | |||
string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||
bool as_ref = false) | |||
{ | |||
var ret = new List<T>(); | |||
var ret = new List<Tensor>(); | |||
foreach((int i, T value) in Python.enumerate(values)) | |||
{ | |||
@@ -325,16 +325,16 @@ namespace Tensorflow | |||
return ret.ToArray(); | |||
} | |||
public static T internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid, | |||
public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid, | |||
string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||
bool as_ref = false) | |||
{ | |||
switch (typeof(T).Name) | |||
{ | |||
case "Tensor": | |||
return value; | |||
return value as Tensor; | |||
default: | |||
throw new NotImplementedException("internal_convert_to_tensor"); | |||
return constant_op.constant(np.array(value), name); | |||
} | |||
} | |||
} | |||