diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 9ea40816..2d0d7d28 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -138,6 +138,9 @@ namespace Tensorflow.Gradients [RegisterNoGradient("GreaterEqual")] public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null; + [RegisterNoGradient("OnesLike")] + public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null; + [RegisterNoGradient("ZerosLike")] public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null; diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 1801d69a..625d76a1 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -274,7 +274,7 @@ namespace Tensorflow { if (elem is EagerTensor eager_tensor) { - if(switch_to_graph) + if (switch_to_graph) elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString())); else elems_as_tensors.Add(eager_tensor); @@ -366,8 +366,30 @@ namespace Tensorflow /// /// /// - public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) - => ones_like_impl(tensor, dtype, name, optimize); + public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + { + return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope => + { + name = scope; + tensor = ops.convert_to_tensor(tensor, name: "tensor"); + + // is_fully_defined return unexpected value. + if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) + { + + } + + if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) + { + throw new NotImplementedException("ones_like"); + // return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); + } + else + { + return gen_array_ops.ones_like(tensor, name: name); + } + }); + } public static Tensor reshape(Tensor tensor, Tensor shape, string name = null) => gen_array_ops.reshape(tensor, shape, name: name); @@ -888,7 +910,7 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "transpose", new { a }), scope => { var a_tensor = ops.convert_to_tensor(a); - if(perm == null) + if (perm == null) { var rank = a_tensor.rank; perm = range(0, rank).OrderByDescending(x => x).ToArray(); @@ -950,7 +972,9 @@ namespace Tensorflow => tf.Context.RunInAutoMode2( () => tf.OpDefLib._apply_op_helper("Slice", name, new { - input, begin, size + input, + begin, + size }).output, () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Slice", name, @@ -966,8 +990,8 @@ namespace Tensorflow tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs); }, new Tensors(input, begin, size)); - - public static Tensor stack(object values, int axis = 0, string name = "stack") + + public static Tensor stack(object values, int axis = 0, string name = "stack") { if (axis == 0) // If the input is a constant list, it can be converted to a constant op diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index a2db25d9..e29227c4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -591,6 +591,15 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor ones_like(Tensor x, string name = null) + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "OnesLike", name, + null, + x).FirstOrDefault(), + x); + public static Tensor zeros_like(Tensor x, string name = null) => tf.Context.RunInAutoMode(() => tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()