Browse Source

fix Variable[slice].assign() #653

tags/v0.30
Oceania2018 4 years ago
parent
commit
d75366ce22
13 changed files with 271 additions and 161 deletions
  1. +0
    -3
      src/TensorFlowNET.Core/APIs/tf.ops.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Contexts/Context.cs
  3. +24
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  4. +0
    -37
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  5. +21
    -0
      src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs
  6. +24
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs
  7. +11
    -64
      src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
  8. +71
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  9. +16
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  10. +66
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs
  11. +0
    -15
      src/TensorFlowNET.Core/Variables/state_ops.cs
  12. +0
    -40
      test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
  13. +36
    -0
      test/TensorFlowNET.UnitTest/Basics/VariableTest.cs

+ 0
- 3
src/TensorFlowNET.Core/APIs/tf.ops.cs View File

@@ -27,9 +27,6 @@ namespace Tensorflow
public void add_to_collections<T>(List<string> names, T value)
=> get_default_graph().add_to_collections(names, value);

public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);

public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);



+ 2
- 2
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -91,7 +91,7 @@ namespace Tensorflow.Contexts
context_switches.Pop();
}

[DebuggerStepThrough]
// [DebuggerStepThrough]
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors)
{
var shouldRunInEager = executing_eagerly()
@@ -115,7 +115,7 @@ namespace Tensorflow.Contexts
}
}

[DebuggerStepThrough]
// [DebuggerStepThrough]
public Tensors RunInAutoMode2(Func<Tensors> graphAction,
Func<Tensors> eagerAction,
Action<Operation> recordGradient,


+ 24
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -593,6 +593,30 @@ namespace Tensorflow
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides);

public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value,
int begin_mask = 0,
int end_mask = 0,
int ellipsis_mask = 0,
int new_axis_mask = 0,
int shrink_axis_mask = 0,
string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new
{
input, begin, end, strides, value,
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResourceStridedSliceAssign", name,
null,
input, begin, end, strides, value,
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides, value);

public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides,
int begin_mask = 0,
int end_mask = 0,


+ 0
- 37
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -34,43 +34,6 @@ namespace Tensorflow
name: name);
}

/// <summary>
///
/// </summary>
/// <param name="self"></param>
/// <param name="value"></param>
/// <param name="use_locking"></param>
/// <param name="read_value"></param>
/// <returns>
/// If `read_value` is `True`, this method will return the new value of the
/// variable after the assignment has completed.Otherwise, when in graph mode
/// it will return the `Operation` that does the assignment, and when in eager
/// mode it will return `None`.
/// </returns>
public static Operation assign(this Tensor self, Tensor value, bool use_locking = false, string name = null, bool read_value = true)
{
var value_tensor = ops.convert_to_tensor(value, dtype: self.dtype);
self.assert_is_compatible_with(value_tensor);
var assign_op = gen_resource_variable_ops.assign_variable_op(self, value_tensor, name: name);
if (read_value)
{
return self._lazy_read(assign_op);
}

return assign_op;
}

public static Operation _lazy_read(this Tensor self, Operation op)
{
variable_accessed(self);
throw new NotImplementedException();
}

public static void variable_accessed(this Tensor variable)
{
throw new NotImplementedException();
}

public static bool is_resource_variable(IVariableV1 var)
{
return var is ResourceVariable;


+ 21
- 0
src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class ParsedSliceArgs
{
public int[] Begin { get; set; }
public Tensor PackedBegin { get; set; }
public int[] End { get; set; }
public Tensor PackedEnd { get; set; }
public int[] Strides { get; set; }
public Tensor PackedStrides { get; set; }
public int BeginMask { get; set; }
public int EndMask { get; set; }
public int ShrinkAxisMask { get; set; }
public int NewAxisMask { get; set; }
public int EllipsisMask { get; set; }
}
}

+ 24
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs View File

@@ -0,0 +1,24 @@
using NumSharp;

namespace Tensorflow
{
public partial class Tensor
{
/// <summary>
/// Used to keep the original variable when slicing
/// </summary>
public ResourceVariable OriginalVar { get; set; }
public ParsedSliceArgs OriginalVarSlice { get; set; }

public ResourceVariable assign(Tensor tensor)
{
if (OriginalVar != null)
{
OriginalVar.StridedSliceAssign(tensor, OriginalVarSlice);
return OriginalVar;
}
else
throw new RuntimeError("Operation doesn't support.");
}
}
}

+ 11
- 64
src/TensorFlowNET.Core/Tensors/Tensor.Index.cs View File

@@ -30,81 +30,28 @@ namespace Tensorflow
{
get
{
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
var args = tensor_util.ParseSlices(slices);

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slices)
{
if (s.IsNewAxis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
new_axis_mask |= (1 << index);
}
else if (s.IsEllipsis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
ellipsis_mask |= (1 << index);
}
else
{
if (s.Start.HasValue)
{
begin.Add(s.Start.Value);
}
else
{
begin.Add(0);
begin_mask |= (1 << index);
}

if (s.Stop.HasValue)
{
end.Add(s.Stop.Value);
}
else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(s.Step);
if (s.IsIndex)
shrink_axis_mask |= (1 << index);
}

index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
return tf_with(ops.name_scope(null, "strided_slice", args), scope =>
{
string name = scope;
if (begin != null)
if (args.Begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
(array_ops.stack(args.Begin),
array_ops.stack(args.End),
array_ops.stack(args.Strides));

return gen_array_ops.strided_slice(
this,
packed_begin,
packed_end,
packed_strides,
begin_mask: begin_mask,
end_mask: end_mask,
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
begin_mask: args.BeginMask,
end_mask: args.EndMask,
shrink_axis_mask: args.ShrinkAxisMask,
new_axis_mask: args.NewAxisMask,
ellipsis_mask: args.EllipsisMask,
name: name);
}



+ 71
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -16,6 +16,7 @@

using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
@@ -584,5 +585,75 @@ would not be rank 1.", tensor.op.get_attr("axis")));
return nd.ToString();
}
}

public static ParsedSliceArgs ParseSlices(Slice[] slices)
{
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slices)
{
if (s.IsNewAxis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
new_axis_mask |= (1 << index);
}
else if (s.IsEllipsis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
ellipsis_mask |= (1 << index);
}
else
{
if (s.Start.HasValue)
{
begin.Add(s.Start.Value);
}
else
{
begin.Add(0);
begin_mask |= (1 << index);
}

if (s.Stop.HasValue)
{
end.Add(s.Stop.Value);
}
else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(s.Step);
if (s.IsIndex)
shrink_axis_mask |= (1 << index);
}

index += 1;
}

return new ParsedSliceArgs
{
Begin = begin.ToArray(),
End = end.ToArray(),
Strides = strides.ToArray(),
BeginMask = begin_mask,
EndMask = end_mask,
EllipsisMask = ellipsis_mask,
ShrinkAxisMask = shrink_axis_mask,
NewAxisMask = new_axis_mask
};
}
}
}

+ 16
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -89,6 +89,22 @@ namespace Tensorflow
return assign_op;
}

public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
{
_strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
}

void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
{
var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
begin_mask: begin_mask,
end_mask: end_mask,
ellipsis_mask: ellipsis_mask,
new_axis_mask: new_axis_mask,
shrink_axis_mask: shrink_axis_mask);
}

public IVariableV1 assign_lazy_load(Tensor value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, dtype: dtype);


+ 66
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs View File

@@ -0,0 +1,66 @@
/*****************************************************************************
Copyright 2020 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class ResourceVariable
{
public Tensor this[params Slice[] slices]
{
get
{
var args = tensor_util.ParseSlices(slices);

return tf_with(ops.name_scope(null, "strided_slice", args), scope =>
{
string name = scope;
if (args.Begin != null)
{
(args.PackedBegin, args.PackedEnd, args.PackedStrides) =
(array_ops.stack(args.Begin),
array_ops.stack(args.End),
array_ops.stack(args.Strides));

var tensor = gen_array_ops.strided_slice(
this,
args.PackedBegin,
args.PackedEnd,
args.PackedStrides,
begin_mask: args.BeginMask,
end_mask: args.EndMask,
shrink_axis_mask: args.ShrinkAxisMask,
new_axis_mask: args.NewAxisMask,
ellipsis_mask: args.EllipsisMask,
name: name);

tensor.OriginalVar = this;
tensor.OriginalVarSlice = args;

return tensor;
}

throw new NotImplementedException("");
});
}
}
}
}

+ 0
- 15
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -40,21 +40,6 @@ namespace Tensorflow
container: container,
shared_name: shared_name);

public static Tensor assign(Tensor @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
if (@ref.dtype.is_ref_dtype())
return gen_state_ops.assign(@ref,
value,
validate_shape: validate_shape,
use_locking: use_locking,
name: name);

return @ref.assign((Tensor)value, name: name);
}

public static Tensor assign<T>(T @ref, object value,
bool validate_shape = true,
bool use_locking = true,


+ 0
- 40
test/TensorFlowNET.UnitTest/Basics/TensorTest.cs View File

@@ -298,45 +298,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI

tf.compat.v1.disable_eager_execution();
}

/// <summary>
/// Assign tensor to slice of other tensor.
/// </summary>
[TestMethod]
public void TestAssignOfficial()
{
// example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__

// python
// import tensorflow as tf
// A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
// with tf.compat.v1.Session() as sess:
// sess.run(tf.compat.v1.global_variables_initializer())
// print(sess.run(A[:2, :2])) # => [[1,2], [4,5]]

// op = A[:2,:2].assign(22. * tf.ones((2, 2)))
// print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]

// C#
// [[1,2,3], [4,5,6], [7,8,9]]
double[][] initial = new double[][]
{
new double[] { 1, 2, 3 },
new double[] { 4, 5, 6 },
new double[] { 7, 8, 9 }
};
Tensor A = tf.Variable(initial, dtype: tf.float32);
// Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]]
Tensor result1 = A[":2", ":2"];
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray<double>()));
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray<double>()));

// An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible'
Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2)));
// Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]]
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray<double>()));
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray<double>()));
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray<double>()));
}
}
}

+ 36
- 0
test/TensorFlowNET.UnitTest/Basics/VariableTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System.Linq;
using static Tensorflow.Binding;

@@ -47,6 +48,41 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual(11f, (float)v1.numpy());
}

/// <summary>
/// Assign tensor to slice of other tensor.
/// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
/// </summary>
[TestMethod]
public void SliceAssign()
{
NDArray nd = new float[,]
{
{ 1, 2, 3 },
{ 4, 5, 6 },
{ 7, 8, 9 }
};
var x = tf.Variable(nd);

// get slice form variable
var sliced = x[":2", ":2"];
Assert.AreEqual(nd[0][":2"], sliced[0].numpy());
Assert.AreEqual(nd[1][":2"], sliced[1].numpy());

// assign to the sliced tensor
sliced.assign(22 * tf.ones((2, 2)));

// test assigned value
nd = new float[,]
{
{ 22, 22, 3 },
{ 22, 22, 6 },
{ 7, 8, 9 }
};
Assert.AreEqual(nd[0], x[0].numpy());
Assert.AreEqual(nd[1], x[1].numpy());
Assert.AreEqual(nd[2], x[2].numpy());
}

[TestMethod]
public void Accumulation()
{


Loading…
Cancel
Save