Browse Source

Merge pull request #736 from MPnoy/ones_like-fix

ones_like fix
tags/yolov3
Haiping GitHub 4 years ago
parent
commit
47f953b94f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 7 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  2. +31
    -7
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +9
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs

+ 3
- 0
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -138,6 +138,9 @@ namespace Tensorflow.Gradients
[RegisterNoGradient("GreaterEqual")]
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("OnesLike")]
public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("ZerosLike")]
public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;



+ 31
- 7
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -274,7 +274,7 @@ namespace Tensorflow
{
if (elem is EagerTensor eager_tensor)
{
if(switch_to_graph)
if (switch_to_graph)
elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString()));
else
elems_as_tensors.Add(eager_tensor);
@@ -366,8 +366,30 @@ namespace Tensorflow
/// <param name="name"></param>
/// <param name="optimize"></param>
/// <returns></returns>
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);
public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{
return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope =>
{
name = scope;
tensor = ops.convert_to_tensor(tensor, name: "tensor");

// is_fully_defined return unexpected value.
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
{

}

if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT)
{
throw new NotImplementedException("ones_like");
// return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
}
else
{
return gen_array_ops.ones_like(tensor, name: name);
}
});
}

public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
=> gen_array_ops.reshape(tensor, shape, name: name);
@@ -888,7 +910,7 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_tensor = ops.convert_to_tensor(a);
if(perm == null)
if (perm == null)
{
var rank = a_tensor.rank;
perm = range(0, rank).OrderByDescending(x => x).ToArray();
@@ -950,7 +972,9 @@ namespace Tensorflow
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Slice", name, new
{
input, begin, size
input,
begin,
size
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Slice", name,
@@ -966,8 +990,8 @@ namespace Tensorflow
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
},
new Tensors(input, begin, size));
public static Tensor stack(object values, int axis = 0, string name = "stack")
public static Tensor stack(object values, int axis = 0, string name = "stack")
{
if (axis == 0)
// If the input is a constant list, it can be converted to a constant op


+ 9
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -591,6 +591,15 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor ones_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OnesLike", name,
null,
x).FirstOrDefault(),
x);

public static Tensor zeros_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()


Loading…
Cancel
Save