Browse Source

fix array_ops.slice #646

tags/v0.30
Oceania2018 4 years ago
parent
commit
b59f5a7290
2 changed files with 32 additions and 15 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs
  2. +29
    -15
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs

+ 3
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -812,6 +812,9 @@ namespace Tensorflow
return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
}

public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name);

public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name);



+ 29
- 15
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using Tensorflow.Contexts;
using static Tensorflow.Binding;
@@ -448,15 +449,34 @@ namespace Tensorflow
return _op.outputs[0];
}

/// <summary>
/// Return a slice from 'input'
/// </summary>
/// <param name="input"></param>
/// <param name="begin"></param>
/// <param name="size"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)
{
if (tf.executing_eagerly())
{
var result = slice_eager_fallback(input, begin, size, name, tf.Context);
return result;
}

var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size });
return _op.outputs[0];
}

private static Tensor slice_eager_fallback(Tensor inputs, Tensor[] begin, Tensor[] size, string name, Context ctx)
{
var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new[] { inputs });
var (_attr_Tidx, _inputs_Index) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { begin, size });
var _inputs_flat = input.concat(_inputs_Index);
var _attrs = new object[] { "T", _attr_T, "Index", _attr_Tidx };

var results = tf.Runner.Execute(ctx, "Slice", 1, _inputs_flat, _attrs, name: name);
if (tf.Runner.MustRecordGradient())
{
tf.Runner.RecordGradient("Slice", _inputs_flat, _attrs, results);
}
return results[0];
}

public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size });
return _op.outputs[0];
@@ -605,12 +625,6 @@ namespace Tensorflow
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
shape, begin, end, strides, dy);

public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size });
return _op.outputs[0];
}

/// <summary>
/// Removes dimensions of size 1 from the shape of a tensor.
/// Given a tensor `input`, this operation returns a tensor of the same type with


Loading…
Cancel
Save