Browse Source

ones_like fix

tags/yolov3
MPnoy 4 years ago
parent
commit
bee3a10ccb
4 changed files with 43 additions and 23 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  2. +31
    -22
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +9
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  4. +0
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs

+ 3
- 0
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -138,6 +138,9 @@ namespace Tensorflow.Gradients
[RegisterNoGradient("GreaterEqual")]
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("OnesLike")]
public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("ZerosLike")]
public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;



+ 31
- 22
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -274,7 +274,7 @@ namespace Tensorflow
{
if (elem is EagerTensor eager_tensor)
{
if(switch_to_graph)
if (switch_to_graph)
elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString()));
else
elems_as_tensors.Add(eager_tensor);
@@ -366,8 +366,30 @@ namespace Tensorflow
/// <param name="name"></param>
/// <param name="optimize"></param>
/// <returns></returns>
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);
public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{
return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope =>
{
name = scope;
tensor = ops.convert_to_tensor(tensor, name: "tensor");

// is_fully_defined return unexpected value.
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
{

}

if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT)
{
throw new NotImplementedException("ones_like");
// return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
}
else
{
return gen_array_ops.ones_like(tensor, name: name);
}
});
}

public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
=> gen_array_ops.reshape(tensor, shape, name: name);
@@ -378,21 +400,6 @@ namespace Tensorflow
public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
=> gen_array_ops.reshape(tensor, shape, name: name);

private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope =>
{
name = scope;
var tensor1 = ops.convert_to_tensor(tensor, name: "tensor");
var ones_shape = shape_internal(tensor1, optimize: optimize);
if (dtype == TF_DataType.DtInvalid)
dtype = tensor1.dtype;
var ret = ones(ones_shape, dtype: dtype, name: name);
ret.shape = tensor1.shape;
return ret;
});
}

public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
@@ -891,7 +898,7 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_tensor = ops.convert_to_tensor(a);
if(perm == null)
if (perm == null)
{
var rank = a_tensor.rank;
perm = range(0, rank).OrderByDescending(x => x).ToArray();
@@ -953,7 +960,9 @@ namespace Tensorflow
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Slice", name, new
{
input, begin, size
input,
begin,
size
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Slice", name,
@@ -969,8 +978,8 @@ namespace Tensorflow
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
},
new Tensors(input, begin, size));
public static Tensor stack(object values, int axis = 0, string name = "stack")
public static Tensor stack(object values, int axis = 0, string name = "stack")
{
if (axis == 0)
// If the input is a constant list, it can be converted to a constant op


+ 9
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -591,6 +591,15 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor ones_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OnesLike", name,
null,
x).FirstOrDefault(),
x);

public static Tensor zeros_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()


+ 0
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs View File

@@ -132,7 +132,6 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
}

#region ones/zeros like
[Ignore]
[TestMethod]
public void TestOnesLike()
{


Loading…
Cancel
Save