diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index ece3baf8..5e0f83e6 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -812,6 +812,9 @@ namespace Tensorflow return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name); } + public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) + => gen_array_ops.slice(input, begin, size, name: name); + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) => gen_array_ops.slice(input, begin, size, name: name); diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index d252e077..51f1c1c9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using System.Linq; using Tensorflow.Contexts; using static Tensorflow.Binding; @@ -448,15 +449,34 @@ namespace Tensorflow return _op.outputs[0]; } - /// - /// Return a slice from 'input' - /// - /// - /// - /// - /// - /// - public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) + public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) + { + if (tf.executing_eagerly()) + { + var result = slice_eager_fallback(input, begin, size, name, tf.Context); + return result; + } + + var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size }); + return _op.outputs[0]; + } + + private static Tensor slice_eager_fallback(Tensor inputs, Tensor[] begin, Tensor[] size, string name, Context ctx) + { + var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new[] { inputs }); + var (_attr_Tidx, _inputs_Index) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { begin, size }); + var _inputs_flat = input.concat(_inputs_Index); + var _attrs = new object[] { "T", _attr_T, "Index", _attr_Tidx }; + + var results = tf.Runner.Execute(ctx, "Slice", 1, _inputs_flat, _attrs, name: name); + if (tf.Runner.MustRecordGradient()) + { + tf.Runner.RecordGradient("Slice", _inputs_flat, _attrs, results); + } + return results[0]; + } + + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) { var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size }); return _op.outputs[0]; @@ -605,12 +625,6 @@ namespace Tensorflow "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), shape, begin, end, strides, dy); - public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size }); - return _op.outputs[0]; - } - /// /// Removes dimensions of size 1 from the shape of a tensor. /// Given a tensor `input`, this operation returns a tensor of the same type with