@@ -27,9 +27,6 @@ namespace Tensorflow | |||
public void add_to_collections<T>(List<string> names, T value) | |||
=> get_default_graph().add_to_collections(names, value); | |||
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) | |||
=> state_ops.assign(@ref, value, validate_shape, use_locking, name); | |||
public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) | |||
=> state_ops.assign(@ref, value, validate_shape, use_locking, name); | |||
@@ -91,7 +91,7 @@ namespace Tensorflow.Contexts | |||
context_switches.Pop(); | |||
} | |||
[DebuggerStepThrough] | |||
// [DebuggerStepThrough] | |||
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors) | |||
{ | |||
var shouldRunInEager = executing_eagerly() | |||
@@ -115,7 +115,7 @@ namespace Tensorflow.Contexts | |||
} | |||
} | |||
[DebuggerStepThrough] | |||
// [DebuggerStepThrough] | |||
public Tensors RunInAutoMode2(Func<Tensors> graphAction, | |||
Func<Tensors> eagerAction, | |||
Action<Operation> recordGradient, | |||
@@ -593,6 +593,30 @@ namespace Tensorflow | |||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
input, begin, end, strides); | |||
public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, | |||
int begin_mask = 0, | |||
int end_mask = 0, | |||
int ellipsis_mask = 0, | |||
int new_axis_mask = 0, | |||
int shrink_axis_mask = 0, | |||
string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new | |||
{ | |||
input, begin, end, strides, value, | |||
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask | |||
}).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ResourceStridedSliceAssign", name, | |||
null, | |||
input, begin, end, strides, value, | |||
"begin_mask", begin_mask, | |||
"end_mask", end_mask, | |||
"ellipsis_mask", ellipsis_mask, | |||
"new_axis_mask", new_axis_mask, | |||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
input, begin, end, strides, value); | |||
public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides, | |||
int begin_mask = 0, | |||
int end_mask = 0, | |||
@@ -34,43 +34,6 @@ namespace Tensorflow | |||
name: name); | |||
} | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="self"></param> | |||
/// <param name="value"></param> | |||
/// <param name="use_locking"></param> | |||
/// <param name="read_value"></param> | |||
/// <returns> | |||
/// If `read_value` is `True`, this method will return the new value of the | |||
/// variable after the assignment has completed.Otherwise, when in graph mode | |||
/// it will return the `Operation` that does the assignment, and when in eager | |||
/// mode it will return `None`. | |||
/// </returns> | |||
public static Operation assign(this Tensor self, Tensor value, bool use_locking = false, string name = null, bool read_value = true) | |||
{ | |||
var value_tensor = ops.convert_to_tensor(value, dtype: self.dtype); | |||
self.assert_is_compatible_with(value_tensor); | |||
var assign_op = gen_resource_variable_ops.assign_variable_op(self, value_tensor, name: name); | |||
if (read_value) | |||
{ | |||
return self._lazy_read(assign_op); | |||
} | |||
return assign_op; | |||
} | |||
public static Operation _lazy_read(this Tensor self, Operation op) | |||
{ | |||
variable_accessed(self); | |||
throw new NotImplementedException(); | |||
} | |||
public static void variable_accessed(this Tensor variable) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public static bool is_resource_variable(IVariableV1 var) | |||
{ | |||
return var is ResourceVariable; | |||
@@ -0,0 +1,21 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class ParsedSliceArgs | |||
{ | |||
public int[] Begin { get; set; } | |||
public Tensor PackedBegin { get; set; } | |||
public int[] End { get; set; } | |||
public Tensor PackedEnd { get; set; } | |||
public int[] Strides { get; set; } | |||
public Tensor PackedStrides { get; set; } | |||
public int BeginMask { get; set; } | |||
public int EndMask { get; set; } | |||
public int ShrinkAxisMask { get; set; } | |||
public int NewAxisMask { get; set; } | |||
public int EllipsisMask { get; set; } | |||
} | |||
} |
@@ -0,0 +1,24 @@ | |||
using NumSharp; | |||
namespace Tensorflow | |||
{ | |||
public partial class Tensor | |||
{ | |||
/// <summary> | |||
/// Used to keep the original variable when slicing | |||
/// </summary> | |||
public ResourceVariable OriginalVar { get; set; } | |||
public ParsedSliceArgs OriginalVarSlice { get; set; } | |||
public ResourceVariable assign(Tensor tensor) | |||
{ | |||
if (OriginalVar != null) | |||
{ | |||
OriginalVar.StridedSliceAssign(tensor, OriginalVarSlice); | |||
return OriginalVar; | |||
} | |||
else | |||
throw new RuntimeError("Operation doesn't support."); | |||
} | |||
} | |||
} |
@@ -30,81 +30,28 @@ namespace Tensorflow | |||
{ | |||
get | |||
{ | |||
var begin = new List<int>(); | |||
var end = new List<int>(); | |||
var strides = new List<int>(); | |||
var args = tensor_util.ParseSlices(slices); | |||
var index = 0; | |||
var (new_axis_mask, shrink_axis_mask) = (0, 0); | |||
var (begin_mask, end_mask) = (0, 0); | |||
var ellipsis_mask = 0; | |||
foreach (var s in slices) | |||
{ | |||
if (s.IsNewAxis) | |||
{ | |||
begin.Add(0); | |||
end.Add(0); | |||
strides.Add(1); | |||
new_axis_mask |= (1 << index); | |||
} | |||
else if (s.IsEllipsis) | |||
{ | |||
begin.Add(0); | |||
end.Add(0); | |||
strides.Add(1); | |||
ellipsis_mask |= (1 << index); | |||
} | |||
else | |||
{ | |||
if (s.Start.HasValue) | |||
{ | |||
begin.Add(s.Start.Value); | |||
} | |||
else | |||
{ | |||
begin.Add(0); | |||
begin_mask |= (1 << index); | |||
} | |||
if (s.Stop.HasValue) | |||
{ | |||
end.Add(s.Stop.Value); | |||
} | |||
else | |||
{ | |||
end.Add(0); | |||
end_mask |= (1 << index); | |||
} | |||
strides.Add(s.Step); | |||
if (s.IsIndex) | |||
shrink_axis_mask |= (1 << index); | |||
} | |||
index += 1; | |||
} | |||
return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||
return tf_with(ops.name_scope(null, "strided_slice", args), scope => | |||
{ | |||
string name = scope; | |||
if (begin != null) | |||
if (args.Begin != null) | |||
{ | |||
var (packed_begin, packed_end, packed_strides) = | |||
(array_ops.stack(begin.ToArray()), | |||
array_ops.stack(end.ToArray()), | |||
array_ops.stack(strides.ToArray())); | |||
(array_ops.stack(args.Begin), | |||
array_ops.stack(args.End), | |||
array_ops.stack(args.Strides)); | |||
return gen_array_ops.strided_slice( | |||
this, | |||
packed_begin, | |||
packed_end, | |||
packed_strides, | |||
begin_mask: begin_mask, | |||
end_mask: end_mask, | |||
shrink_axis_mask: shrink_axis_mask, | |||
new_axis_mask: new_axis_mask, | |||
ellipsis_mask: ellipsis_mask, | |||
begin_mask: args.BeginMask, | |||
end_mask: args.EndMask, | |||
shrink_axis_mask: args.ShrinkAxisMask, | |||
new_axis_mask: args.NewAxisMask, | |||
ellipsis_mask: args.EllipsisMask, | |||
name: name); | |||
} | |||
@@ -16,6 +16,7 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
@@ -584,5 +585,75 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||
return nd.ToString(); | |||
} | |||
} | |||
public static ParsedSliceArgs ParseSlices(Slice[] slices) | |||
{ | |||
var begin = new List<int>(); | |||
var end = new List<int>(); | |||
var strides = new List<int>(); | |||
var index = 0; | |||
var (new_axis_mask, shrink_axis_mask) = (0, 0); | |||
var (begin_mask, end_mask) = (0, 0); | |||
var ellipsis_mask = 0; | |||
foreach (var s in slices) | |||
{ | |||
if (s.IsNewAxis) | |||
{ | |||
begin.Add(0); | |||
end.Add(0); | |||
strides.Add(1); | |||
new_axis_mask |= (1 << index); | |||
} | |||
else if (s.IsEllipsis) | |||
{ | |||
begin.Add(0); | |||
end.Add(0); | |||
strides.Add(1); | |||
ellipsis_mask |= (1 << index); | |||
} | |||
else | |||
{ | |||
if (s.Start.HasValue) | |||
{ | |||
begin.Add(s.Start.Value); | |||
} | |||
else | |||
{ | |||
begin.Add(0); | |||
begin_mask |= (1 << index); | |||
} | |||
if (s.Stop.HasValue) | |||
{ | |||
end.Add(s.Stop.Value); | |||
} | |||
else | |||
{ | |||
end.Add(0); | |||
end_mask |= (1 << index); | |||
} | |||
strides.Add(s.Step); | |||
if (s.IsIndex) | |||
shrink_axis_mask |= (1 << index); | |||
} | |||
index += 1; | |||
} | |||
return new ParsedSliceArgs | |||
{ | |||
Begin = begin.ToArray(), | |||
End = end.ToArray(), | |||
Strides = strides.ToArray(), | |||
BeginMask = begin_mask, | |||
EndMask = end_mask, | |||
EllipsisMask = ellipsis_mask, | |||
ShrinkAxisMask = shrink_axis_mask, | |||
NewAxisMask = new_axis_mask | |||
}; | |||
} | |||
} | |||
} |
@@ -89,6 +89,22 @@ namespace Tensorflow | |||
return assign_op; | |||
} | |||
public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice) | |||
{ | |||
_strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value); | |||
} | |||
void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null, | |||
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0) | |||
{ | |||
var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value, | |||
begin_mask: begin_mask, | |||
end_mask: end_mask, | |||
ellipsis_mask: ellipsis_mask, | |||
new_axis_mask: new_axis_mask, | |||
shrink_axis_mask: shrink_axis_mask); | |||
} | |||
public IVariableV1 assign_lazy_load(Tensor value, string name = null) | |||
{ | |||
var value_tensor = ops.convert_to_tensor(value, dtype: dtype); | |||
@@ -0,0 +1,66 @@ | |||
/***************************************************************************** | |||
Copyright 2020 Haiping Chen. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public partial class ResourceVariable | |||
{ | |||
public Tensor this[params Slice[] slices] | |||
{ | |||
get | |||
{ | |||
var args = tensor_util.ParseSlices(slices); | |||
return tf_with(ops.name_scope(null, "strided_slice", args), scope => | |||
{ | |||
string name = scope; | |||
if (args.Begin != null) | |||
{ | |||
(args.PackedBegin, args.PackedEnd, args.PackedStrides) = | |||
(array_ops.stack(args.Begin), | |||
array_ops.stack(args.End), | |||
array_ops.stack(args.Strides)); | |||
var tensor = gen_array_ops.strided_slice( | |||
this, | |||
args.PackedBegin, | |||
args.PackedEnd, | |||
args.PackedStrides, | |||
begin_mask: args.BeginMask, | |||
end_mask: args.EndMask, | |||
shrink_axis_mask: args.ShrinkAxisMask, | |||
new_axis_mask: args.NewAxisMask, | |||
ellipsis_mask: args.EllipsisMask, | |||
name: name); | |||
tensor.OriginalVar = this; | |||
tensor.OriginalVarSlice = args; | |||
return tensor; | |||
} | |||
throw new NotImplementedException(""); | |||
}); | |||
} | |||
} | |||
} | |||
} |
@@ -40,21 +40,6 @@ namespace Tensorflow | |||
container: container, | |||
shared_name: shared_name); | |||
public static Tensor assign(Tensor @ref, object value, | |||
bool validate_shape = true, | |||
bool use_locking = true, | |||
string name = null) | |||
{ | |||
if (@ref.dtype.is_ref_dtype()) | |||
return gen_state_ops.assign(@ref, | |||
value, | |||
validate_shape: validate_shape, | |||
use_locking: use_locking, | |||
name: name); | |||
return @ref.assign((Tensor)value, name: name); | |||
} | |||
public static Tensor assign<T>(T @ref, object value, | |||
bool validate_shape = true, | |||
bool use_locking = true, | |||
@@ -298,45 +298,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
tf.compat.v1.disable_eager_execution(); | |||
} | |||
/// <summary> | |||
/// Assign tensor to slice of other tensor. | |||
/// </summary> | |||
[TestMethod] | |||
public void TestAssignOfficial() | |||
{ | |||
// example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ | |||
// python | |||
// import tensorflow as tf | |||
// A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32) | |||
// with tf.compat.v1.Session() as sess: | |||
// sess.run(tf.compat.v1.global_variables_initializer()) | |||
// print(sess.run(A[:2, :2])) # => [[1,2], [4,5]] | |||
// op = A[:2,:2].assign(22. * tf.ones((2, 2))) | |||
// print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]] | |||
// C# | |||
// [[1,2,3], [4,5,6], [7,8,9]] | |||
double[][] initial = new double[][] | |||
{ | |||
new double[] { 1, 2, 3 }, | |||
new double[] { 4, 5, 6 }, | |||
new double[] { 7, 8, 9 } | |||
}; | |||
Tensor A = tf.Variable(initial, dtype: tf.float32); | |||
// Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]] | |||
Tensor result1 = A[":2", ":2"]; | |||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray<double>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray<double>())); | |||
// An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible' | |||
Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2))); | |||
// Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]] | |||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray<double>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray<double>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray<double>())); | |||
} | |||
} | |||
} |
@@ -1,4 +1,5 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
@@ -47,6 +48,41 @@ namespace TensorFlowNET.UnitTest.Basics | |||
Assert.AreEqual(11f, (float)v1.numpy()); | |||
} | |||
/// <summary> | |||
/// Assign tensor to slice of other tensor. | |||
/// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ | |||
/// </summary> | |||
[TestMethod] | |||
public void SliceAssign() | |||
{ | |||
NDArray nd = new float[,] | |||
{ | |||
{ 1, 2, 3 }, | |||
{ 4, 5, 6 }, | |||
{ 7, 8, 9 } | |||
}; | |||
var x = tf.Variable(nd); | |||
// get slice form variable | |||
var sliced = x[":2", ":2"]; | |||
Assert.AreEqual(nd[0][":2"], sliced[0].numpy()); | |||
Assert.AreEqual(nd[1][":2"], sliced[1].numpy()); | |||
// assign to the sliced tensor | |||
sliced.assign(22 * tf.ones((2, 2))); | |||
// test assigned value | |||
nd = new float[,] | |||
{ | |||
{ 22, 22, 3 }, | |||
{ 22, 22, 6 }, | |||
{ 7, 8, 9 } | |||
}; | |||
Assert.AreEqual(nd[0], x[0].numpy()); | |||
Assert.AreEqual(nd[1], x[1].numpy()); | |||
Assert.AreEqual(nd[2], x[2].numpy()); | |||
} | |||
[TestMethod] | |||
public void Accumulation() | |||
{ | |||