|
|
@@ -274,7 +274,7 @@ namespace Tensorflow |
|
|
|
{ |
|
|
|
if (elem is EagerTensor eager_tensor) |
|
|
|
{ |
|
|
|
if(switch_to_graph) |
|
|
|
if (switch_to_graph) |
|
|
|
elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString())); |
|
|
|
else |
|
|
|
elems_as_tensors.Add(eager_tensor); |
|
|
@@ -366,8 +366,30 @@ namespace Tensorflow |
|
|
|
/// <param name="name"></param> |
|
|
|
/// <param name="optimize"></param> |
|
|
|
/// <returns></returns> |
|
|
|
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) |
|
|
|
=> ones_like_impl(tensor, dtype, name, optimize); |
|
|
|
public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) |
|
|
|
{ |
|
|
|
return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope => |
|
|
|
{ |
|
|
|
name = scope; |
|
|
|
tensor = ops.convert_to_tensor(tensor, name: "tensor"); |
|
|
|
|
|
|
|
// is_fully_defined return unexpected value. |
|
|
|
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) |
|
|
|
{ |
|
|
|
throw new NotImplementedException("ones_like"); |
|
|
|
// return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
return gen_array_ops.ones_like(tensor, name: name); |
|
|
|
} |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
public static Tensor reshape(Tensor tensor, Tensor shape, string name = null) |
|
|
|
=> gen_array_ops.reshape(tensor, shape, name: name); |
|
|
@@ -888,7 +910,7 @@ namespace Tensorflow |
|
|
|
return tf_with(ops.name_scope(name, "transpose", new { a }), scope => |
|
|
|
{ |
|
|
|
var a_tensor = ops.convert_to_tensor(a); |
|
|
|
if(perm == null) |
|
|
|
if (perm == null) |
|
|
|
{ |
|
|
|
var rank = a_tensor.rank; |
|
|
|
perm = range(0, rank).OrderByDescending(x => x).ToArray(); |
|
|
@@ -950,7 +972,9 @@ namespace Tensorflow |
|
|
|
=> tf.Context.RunInAutoMode2( |
|
|
|
() => tf.OpDefLib._apply_op_helper("Slice", name, new |
|
|
|
{ |
|
|
|
input, begin, size |
|
|
|
input, |
|
|
|
begin, |
|
|
|
size |
|
|
|
}).output, |
|
|
|
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, |
|
|
|
"Slice", name, |
|
|
@@ -966,8 +990,8 @@ namespace Tensorflow |
|
|
|
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs); |
|
|
|
}, |
|
|
|
new Tensors(input, begin, size)); |
|
|
|
|
|
|
|
public static Tensor stack(object values, int axis = 0, string name = "stack") |
|
|
|
|
|
|
|
public static Tensor stack(object values, int axis = 0, string name = "stack") |
|
|
|
{ |
|
|
|
if (axis == 0) |
|
|
|
// If the input is a constant list, it can be converted to a constant op |
|
|
|