diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index 4308ff91..7e4c3395 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -27,9 +27,6 @@ namespace Tensorflow public void add_to_collections(List names, T value) => get_default_graph().add_to_collections(names, value); - public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) - => state_ops.assign(@ref, value, validate_shape, use_locking, name); - public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) => state_ops.assign(@ref, value, validate_shape, use_locking, name); diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 709b478d..3ba50a62 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -91,7 +91,7 @@ namespace Tensorflow.Contexts context_switches.Pop(); } - [DebuggerStepThrough] + // [DebuggerStepThrough] public T RunInAutoMode(Func graphAction, Func eagerAction, params Tensor[] tensors) { var shouldRunInEager = executing_eagerly() @@ -115,7 +115,7 @@ namespace Tensorflow.Contexts } } - [DebuggerStepThrough] + // [DebuggerStepThrough] public Tensors RunInAutoMode2(Func graphAction, Func eagerAction, Action recordGradient, diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 814c6fc3..3087639b 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -593,6 +593,30 @@ namespace Tensorflow "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), input, begin, end, strides); + public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, + int begin_mask = 0, + int end_mask = 0, + int ellipsis_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0, + string name = null) + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new + { + input, begin, end, strides, value, + begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask + }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ResourceStridedSliceAssign", name, + null, + input, begin, end, strides, value, + "begin_mask", begin_mask, + "end_mask", end_mask, + "ellipsis_mask", ellipsis_mask, + "new_axis_mask", new_axis_mask, + "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), + input, begin, end, strides, value); + public static Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides, int begin_mask = 0, int end_mask = 0, diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 09773a71..861f99ed 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -34,43 +34,6 @@ namespace Tensorflow name: name); } - /// - /// - /// - /// - /// - /// - /// - /// - /// If `read_value` is `True`, this method will return the new value of the - /// variable after the assignment has completed.Otherwise, when in graph mode - /// it will return the `Operation` that does the assignment, and when in eager - /// mode it will return `None`. - /// - public static Operation assign(this Tensor self, Tensor value, bool use_locking = false, string name = null, bool read_value = true) - { - var value_tensor = ops.convert_to_tensor(value, dtype: self.dtype); - self.assert_is_compatible_with(value_tensor); - var assign_op = gen_resource_variable_ops.assign_variable_op(self, value_tensor, name: name); - if (read_value) - { - return self._lazy_read(assign_op); - } - - return assign_op; - } - - public static Operation _lazy_read(this Tensor self, Operation op) - { - variable_accessed(self); - throw new NotImplementedException(); - } - - public static void variable_accessed(this Tensor variable) - { - throw new NotImplementedException(); - } - public static bool is_resource_variable(IVariableV1 var) { return var is ResourceVariable; diff --git a/src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs b/src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs new file mode 100644 index 00000000..c6404c3f --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class ParsedSliceArgs + { + public int[] Begin { get; set; } + public Tensor PackedBegin { get; set; } + public int[] End { get; set; } + public Tensor PackedEnd { get; set; } + public int[] Strides { get; set; } + public Tensor PackedStrides { get; set; } + public int BeginMask { get; set; } + public int EndMask { get; set; } + public int ShrinkAxisMask { get; set; } + public int NewAxisMask { get; set; } + public int EllipsisMask { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs new file mode 100644 index 00000000..eaaae613 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs @@ -0,0 +1,24 @@ +using NumSharp; + +namespace Tensorflow +{ + public partial class Tensor + { + /// + /// Used to keep the original variable when slicing + /// + public ResourceVariable OriginalVar { get; set; } + public ParsedSliceArgs OriginalVarSlice { get; set; } + + public ResourceVariable assign(Tensor tensor) + { + if (OriginalVar != null) + { + OriginalVar.StridedSliceAssign(tensor, OriginalVarSlice); + return OriginalVar; + } + else + throw new RuntimeError("Operation doesn't support."); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs index 239780ea..510ef689 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -30,81 +30,28 @@ namespace Tensorflow { get { - var begin = new List(); - var end = new List(); - var strides = new List(); + var args = tensor_util.ParseSlices(slices); - var index = 0; - var (new_axis_mask, shrink_axis_mask) = (0, 0); - var (begin_mask, end_mask) = (0, 0); - var ellipsis_mask = 0; - - foreach (var s in slices) - { - if (s.IsNewAxis) - { - begin.Add(0); - end.Add(0); - strides.Add(1); - new_axis_mask |= (1 << index); - } - else if (s.IsEllipsis) - { - begin.Add(0); - end.Add(0); - strides.Add(1); - ellipsis_mask |= (1 << index); - } - else - { - if (s.Start.HasValue) - { - begin.Add(s.Start.Value); - } - else - { - begin.Add(0); - begin_mask |= (1 << index); - } - - if (s.Stop.HasValue) - { - end.Add(s.Stop.Value); - } - else - { - end.Add(0); - end_mask |= (1 << index); - } - - strides.Add(s.Step); - if (s.IsIndex) - shrink_axis_mask |= (1 << index); - } - - index += 1; - } - - return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + return tf_with(ops.name_scope(null, "strided_slice", args), scope => { string name = scope; - if (begin != null) + if (args.Begin != null) { var (packed_begin, packed_end, packed_strides) = - (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); + (array_ops.stack(args.Begin), + array_ops.stack(args.End), + array_ops.stack(args.Strides)); return gen_array_ops.strided_slice( this, packed_begin, packed_end, packed_strides, - begin_mask: begin_mask, - end_mask: end_mask, - shrink_axis_mask: shrink_axis_mask, - new_axis_mask: new_axis_mask, - ellipsis_mask: ellipsis_mask, + begin_mask: args.BeginMask, + end_mask: args.EndMask, + shrink_axis_mask: args.ShrinkAxisMask, + new_axis_mask: args.NewAxisMask, + ellipsis_mask: args.EllipsisMask, name: name); } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 87f16380..7a665c23 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -16,6 +16,7 @@ using NumSharp; using System; +using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Eager; @@ -584,5 +585,75 @@ would not be rank 1.", tensor.op.get_attr("axis"))); return nd.ToString(); } } + + public static ParsedSliceArgs ParseSlices(Slice[] slices) + { + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slices) + { + if (s.IsNewAxis) + { + begin.Add(0); + end.Add(0); + strides.Add(1); + new_axis_mask |= (1 << index); + } + else if (s.IsEllipsis) + { + begin.Add(0); + end.Add(0); + strides.Add(1); + ellipsis_mask |= (1 << index); + } + else + { + if (s.Start.HasValue) + { + begin.Add(s.Start.Value); + } + else + { + begin.Add(0); + begin_mask |= (1 << index); + } + + if (s.Stop.HasValue) + { + end.Add(s.Stop.Value); + } + else + { + end.Add(0); + end_mask |= (1 << index); + } + + strides.Add(s.Step); + if (s.IsIndex) + shrink_axis_mask |= (1 << index); + } + + index += 1; + } + + return new ParsedSliceArgs + { + Begin = begin.ToArray(), + End = end.ToArray(), + Strides = strides.ToArray(), + BeginMask = begin_mask, + EndMask = end_mask, + EllipsisMask = ellipsis_mask, + ShrinkAxisMask = shrink_axis_mask, + NewAxisMask = new_axis_mask + }; + } } } diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index c30a0be2..a504c61b 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -89,6 +89,22 @@ namespace Tensorflow return assign_op; } + public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice) + { + _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value); + } + + void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null, + int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0) + { + var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value, + begin_mask: begin_mask, + end_mask: end_mask, + ellipsis_mask: ellipsis_mask, + new_axis_mask: new_axis_mask, + shrink_axis_mask: shrink_axis_mask); + } + public IVariableV1 assign_lazy_load(Tensor value, string name = null) { var value_tensor = ops.convert_to_tensor(value, dtype: dtype); diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs new file mode 100644 index 00000000..80521a29 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs @@ -0,0 +1,66 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class ResourceVariable + { + public Tensor this[params Slice[] slices] + { + get + { + var args = tensor_util.ParseSlices(slices); + + return tf_with(ops.name_scope(null, "strided_slice", args), scope => + { + string name = scope; + if (args.Begin != null) + { + (args.PackedBegin, args.PackedEnd, args.PackedStrides) = + (array_ops.stack(args.Begin), + array_ops.stack(args.End), + array_ops.stack(args.Strides)); + + var tensor = gen_array_ops.strided_slice( + this, + args.PackedBegin, + args.PackedEnd, + args.PackedStrides, + begin_mask: args.BeginMask, + end_mask: args.EndMask, + shrink_axis_mask: args.ShrinkAxisMask, + new_axis_mask: args.NewAxisMask, + ellipsis_mask: args.EllipsisMask, + name: name); + + tensor.OriginalVar = this; + tensor.OriginalVarSlice = args; + + return tensor; + } + + throw new NotImplementedException(""); + }); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index ce587397..6d79f906 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -40,21 +40,6 @@ namespace Tensorflow container: container, shared_name: shared_name); - public static Tensor assign(Tensor @ref, object value, - bool validate_shape = true, - bool use_locking = true, - string name = null) - { - if (@ref.dtype.is_ref_dtype()) - return gen_state_ops.assign(@ref, - value, - validate_shape: validate_shape, - use_locking: use_locking, - name: name); - - return @ref.assign((Tensor)value, name: name); - } - public static Tensor assign(T @ref, object value, bool validate_shape = true, bool use_locking = true, diff --git a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs index 2811b850..a81074b3 100644 --- a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs @@ -298,45 +298,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI tf.compat.v1.disable_eager_execution(); } - - /// - /// Assign tensor to slice of other tensor. - /// - [TestMethod] - public void TestAssignOfficial() - { - // example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ - - // python - // import tensorflow as tf - // A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32) - // with tf.compat.v1.Session() as sess: - // sess.run(tf.compat.v1.global_variables_initializer()) - // print(sess.run(A[:2, :2])) # => [[1,2], [4,5]] - - // op = A[:2,:2].assign(22. * tf.ones((2, 2))) - // print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]] - - // C# - // [[1,2,3], [4,5,6], [7,8,9]] - double[][] initial = new double[][] - { - new double[] { 1, 2, 3 }, - new double[] { 4, 5, 6 }, - new double[] { 7, 8, 9 } - }; - Tensor A = tf.Variable(initial, dtype: tf.float32); - // Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]] - Tensor result1 = A[":2", ":2"]; - Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray())); - Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray())); - - // An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible' - Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2))); - // Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]] - Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray())); - Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray())); - Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray())); - } } } \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs index 27371c41..c4bb4729 100644 --- a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; using System.Linq; using static Tensorflow.Binding; @@ -47,6 +48,41 @@ namespace TensorFlowNET.UnitTest.Basics Assert.AreEqual(11f, (float)v1.numpy()); } + /// + /// Assign tensor to slice of other tensor. + /// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ + /// + [TestMethod] + public void SliceAssign() + { + NDArray nd = new float[,] + { + { 1, 2, 3 }, + { 4, 5, 6 }, + { 7, 8, 9 } + }; + var x = tf.Variable(nd); + + // get slice form variable + var sliced = x[":2", ":2"]; + Assert.AreEqual(nd[0][":2"], sliced[0].numpy()); + Assert.AreEqual(nd[1][":2"], sliced[1].numpy()); + + // assign to the sliced tensor + sliced.assign(22 * tf.ones((2, 2))); + + // test assigned value + nd = new float[,] + { + { 22, 22, 3 }, + { 22, 22, 6 }, + { 7, 8, 9 } + }; + Assert.AreEqual(nd[0], x[0].numpy()); + Assert.AreEqual(nd[1], x[1].numpy()); + Assert.AreEqual(nd[2], x[2].numpy()); + } + [TestMethod] public void Accumulation() {