@@ -27,9 +27,6 @@ namespace Tensorflow | |||||
public void add_to_collections<T>(List<string> names, T value) | public void add_to_collections<T>(List<string> names, T value) | ||||
=> get_default_graph().add_to_collections(names, value); | => get_default_graph().add_to_collections(names, value); | ||||
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) | |||||
=> state_ops.assign(@ref, value, validate_shape, use_locking, name); | |||||
public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) | public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) | ||||
=> state_ops.assign(@ref, value, validate_shape, use_locking, name); | => state_ops.assign(@ref, value, validate_shape, use_locking, name); | ||||
@@ -91,7 +91,7 @@ namespace Tensorflow.Contexts | |||||
context_switches.Pop(); | context_switches.Pop(); | ||||
} | } | ||||
[DebuggerStepThrough] | |||||
// [DebuggerStepThrough] | |||||
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors) | public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors) | ||||
{ | { | ||||
var shouldRunInEager = executing_eagerly() | var shouldRunInEager = executing_eagerly() | ||||
@@ -115,7 +115,7 @@ namespace Tensorflow.Contexts | |||||
} | } | ||||
} | } | ||||
[DebuggerStepThrough] | |||||
// [DebuggerStepThrough] | |||||
public Tensors RunInAutoMode2(Func<Tensors> graphAction, | public Tensors RunInAutoMode2(Func<Tensors> graphAction, | ||||
Func<Tensors> eagerAction, | Func<Tensors> eagerAction, | ||||
Action<Operation> recordGradient, | Action<Operation> recordGradient, | ||||
@@ -593,6 +593,30 @@ namespace Tensorflow | |||||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | ||||
input, begin, end, strides); | input, begin, end, strides); | ||||
public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, | |||||
int begin_mask = 0, | |||||
int end_mask = 0, | |||||
int ellipsis_mask = 0, | |||||
int new_axis_mask = 0, | |||||
int shrink_axis_mask = 0, | |||||
string name = null) | |||||
=> tf.Context.RunInAutoMode(() | |||||
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new | |||||
{ | |||||
input, begin, end, strides, value, | |||||
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask | |||||
}).output, () | |||||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"ResourceStridedSliceAssign", name, | |||||
null, | |||||
input, begin, end, strides, value, | |||||
"begin_mask", begin_mask, | |||||
"end_mask", end_mask, | |||||
"ellipsis_mask", ellipsis_mask, | |||||
"new_axis_mask", new_axis_mask, | |||||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||||
input, begin, end, strides, value); | |||||
public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides, | public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides, | ||||
int begin_mask = 0, | int begin_mask = 0, | ||||
int end_mask = 0, | int end_mask = 0, | ||||
@@ -34,43 +34,6 @@ namespace Tensorflow | |||||
name: name); | name: name); | ||||
} | } | ||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="self"></param> | |||||
/// <param name="value"></param> | |||||
/// <param name="use_locking"></param> | |||||
/// <param name="read_value"></param> | |||||
/// <returns> | |||||
/// If `read_value` is `True`, this method will return the new value of the | |||||
/// variable after the assignment has completed.Otherwise, when in graph mode | |||||
/// it will return the `Operation` that does the assignment, and when in eager | |||||
/// mode it will return `None`. | |||||
/// </returns> | |||||
public static Operation assign(this Tensor self, Tensor value, bool use_locking = false, string name = null, bool read_value = true) | |||||
{ | |||||
var value_tensor = ops.convert_to_tensor(value, dtype: self.dtype); | |||||
self.assert_is_compatible_with(value_tensor); | |||||
var assign_op = gen_resource_variable_ops.assign_variable_op(self, value_tensor, name: name); | |||||
if (read_value) | |||||
{ | |||||
return self._lazy_read(assign_op); | |||||
} | |||||
return assign_op; | |||||
} | |||||
public static Operation _lazy_read(this Tensor self, Operation op) | |||||
{ | |||||
variable_accessed(self); | |||||
throw new NotImplementedException(); | |||||
} | |||||
public static void variable_accessed(this Tensor variable) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public static bool is_resource_variable(IVariableV1 var) | public static bool is_resource_variable(IVariableV1 var) | ||||
{ | { | ||||
return var is ResourceVariable; | return var is ResourceVariable; | ||||
@@ -0,0 +1,21 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class ParsedSliceArgs | |||||
{ | |||||
public int[] Begin { get; set; } | |||||
public Tensor PackedBegin { get; set; } | |||||
public int[] End { get; set; } | |||||
public Tensor PackedEnd { get; set; } | |||||
public int[] Strides { get; set; } | |||||
public Tensor PackedStrides { get; set; } | |||||
public int BeginMask { get; set; } | |||||
public int EndMask { get; set; } | |||||
public int ShrinkAxisMask { get; set; } | |||||
public int NewAxisMask { get; set; } | |||||
public int EllipsisMask { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,24 @@ | |||||
using NumSharp; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Tensor | |||||
{ | |||||
/// <summary> | |||||
/// Used to keep the original variable when slicing | |||||
/// </summary> | |||||
public ResourceVariable OriginalVar { get; set; } | |||||
public ParsedSliceArgs OriginalVarSlice { get; set; } | |||||
public ResourceVariable assign(Tensor tensor) | |||||
{ | |||||
if (OriginalVar != null) | |||||
{ | |||||
OriginalVar.StridedSliceAssign(tensor, OriginalVarSlice); | |||||
return OriginalVar; | |||||
} | |||||
else | |||||
throw new RuntimeError("Operation doesn't support."); | |||||
} | |||||
} | |||||
} |
@@ -30,81 +30,28 @@ namespace Tensorflow | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
var begin = new List<int>(); | |||||
var end = new List<int>(); | |||||
var strides = new List<int>(); | |||||
var args = tensor_util.ParseSlices(slices); | |||||
var index = 0; | |||||
var (new_axis_mask, shrink_axis_mask) = (0, 0); | |||||
var (begin_mask, end_mask) = (0, 0); | |||||
var ellipsis_mask = 0; | |||||
foreach (var s in slices) | |||||
{ | |||||
if (s.IsNewAxis) | |||||
{ | |||||
begin.Add(0); | |||||
end.Add(0); | |||||
strides.Add(1); | |||||
new_axis_mask |= (1 << index); | |||||
} | |||||
else if (s.IsEllipsis) | |||||
{ | |||||
begin.Add(0); | |||||
end.Add(0); | |||||
strides.Add(1); | |||||
ellipsis_mask |= (1 << index); | |||||
} | |||||
else | |||||
{ | |||||
if (s.Start.HasValue) | |||||
{ | |||||
begin.Add(s.Start.Value); | |||||
} | |||||
else | |||||
{ | |||||
begin.Add(0); | |||||
begin_mask |= (1 << index); | |||||
} | |||||
if (s.Stop.HasValue) | |||||
{ | |||||
end.Add(s.Stop.Value); | |||||
} | |||||
else | |||||
{ | |||||
end.Add(0); | |||||
end_mask |= (1 << index); | |||||
} | |||||
strides.Add(s.Step); | |||||
if (s.IsIndex) | |||||
shrink_axis_mask |= (1 << index); | |||||
} | |||||
index += 1; | |||||
} | |||||
return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
return tf_with(ops.name_scope(null, "strided_slice", args), scope => | |||||
{ | { | ||||
string name = scope; | string name = scope; | ||||
if (begin != null) | |||||
if (args.Begin != null) | |||||
{ | { | ||||
var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
(array_ops.stack(begin.ToArray()), | |||||
array_ops.stack(end.ToArray()), | |||||
array_ops.stack(strides.ToArray())); | |||||
(array_ops.stack(args.Begin), | |||||
array_ops.stack(args.End), | |||||
array_ops.stack(args.Strides)); | |||||
return gen_array_ops.strided_slice( | return gen_array_ops.strided_slice( | ||||
this, | this, | ||||
packed_begin, | packed_begin, | ||||
packed_end, | packed_end, | ||||
packed_strides, | packed_strides, | ||||
begin_mask: begin_mask, | |||||
end_mask: end_mask, | |||||
shrink_axis_mask: shrink_axis_mask, | |||||
new_axis_mask: new_axis_mask, | |||||
ellipsis_mask: ellipsis_mask, | |||||
begin_mask: args.BeginMask, | |||||
end_mask: args.EndMask, | |||||
shrink_axis_mask: args.ShrinkAxisMask, | |||||
new_axis_mask: args.NewAxisMask, | |||||
ellipsis_mask: args.EllipsisMask, | |||||
name: name); | name: name); | ||||
} | } | ||||
@@ -16,6 +16,7 @@ | |||||
using NumSharp; | using NumSharp; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
@@ -584,5 +585,75 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
return nd.ToString(); | return nd.ToString(); | ||||
} | } | ||||
} | } | ||||
public static ParsedSliceArgs ParseSlices(Slice[] slices) | |||||
{ | |||||
var begin = new List<int>(); | |||||
var end = new List<int>(); | |||||
var strides = new List<int>(); | |||||
var index = 0; | |||||
var (new_axis_mask, shrink_axis_mask) = (0, 0); | |||||
var (begin_mask, end_mask) = (0, 0); | |||||
var ellipsis_mask = 0; | |||||
foreach (var s in slices) | |||||
{ | |||||
if (s.IsNewAxis) | |||||
{ | |||||
begin.Add(0); | |||||
end.Add(0); | |||||
strides.Add(1); | |||||
new_axis_mask |= (1 << index); | |||||
} | |||||
else if (s.IsEllipsis) | |||||
{ | |||||
begin.Add(0); | |||||
end.Add(0); | |||||
strides.Add(1); | |||||
ellipsis_mask |= (1 << index); | |||||
} | |||||
else | |||||
{ | |||||
if (s.Start.HasValue) | |||||
{ | |||||
begin.Add(s.Start.Value); | |||||
} | |||||
else | |||||
{ | |||||
begin.Add(0); | |||||
begin_mask |= (1 << index); | |||||
} | |||||
if (s.Stop.HasValue) | |||||
{ | |||||
end.Add(s.Stop.Value); | |||||
} | |||||
else | |||||
{ | |||||
end.Add(0); | |||||
end_mask |= (1 << index); | |||||
} | |||||
strides.Add(s.Step); | |||||
if (s.IsIndex) | |||||
shrink_axis_mask |= (1 << index); | |||||
} | |||||
index += 1; | |||||
} | |||||
return new ParsedSliceArgs | |||||
{ | |||||
Begin = begin.ToArray(), | |||||
End = end.ToArray(), | |||||
Strides = strides.ToArray(), | |||||
BeginMask = begin_mask, | |||||
EndMask = end_mask, | |||||
EllipsisMask = ellipsis_mask, | |||||
ShrinkAxisMask = shrink_axis_mask, | |||||
NewAxisMask = new_axis_mask | |||||
}; | |||||
} | |||||
} | } | ||||
} | } |
@@ -89,6 +89,22 @@ namespace Tensorflow | |||||
return assign_op; | return assign_op; | ||||
} | } | ||||
public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice) | |||||
{ | |||||
_strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value); | |||||
} | |||||
void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null, | |||||
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0) | |||||
{ | |||||
var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value, | |||||
begin_mask: begin_mask, | |||||
end_mask: end_mask, | |||||
ellipsis_mask: ellipsis_mask, | |||||
new_axis_mask: new_axis_mask, | |||||
shrink_axis_mask: shrink_axis_mask); | |||||
} | |||||
public IVariableV1 assign_lazy_load(Tensor value, string name = null) | public IVariableV1 assign_lazy_load(Tensor value, string name = null) | ||||
{ | { | ||||
var value_tensor = ops.convert_to_tensor(value, dtype: dtype); | var value_tensor = ops.convert_to_tensor(value, dtype: dtype); | ||||
@@ -0,0 +1,66 @@ | |||||
/***************************************************************************** | |||||
Copyright 2020 Haiping Chen. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class ResourceVariable | |||||
{ | |||||
public Tensor this[params Slice[] slices] | |||||
{ | |||||
get | |||||
{ | |||||
var args = tensor_util.ParseSlices(slices); | |||||
return tf_with(ops.name_scope(null, "strided_slice", args), scope => | |||||
{ | |||||
string name = scope; | |||||
if (args.Begin != null) | |||||
{ | |||||
(args.PackedBegin, args.PackedEnd, args.PackedStrides) = | |||||
(array_ops.stack(args.Begin), | |||||
array_ops.stack(args.End), | |||||
array_ops.stack(args.Strides)); | |||||
var tensor = gen_array_ops.strided_slice( | |||||
this, | |||||
args.PackedBegin, | |||||
args.PackedEnd, | |||||
args.PackedStrides, | |||||
begin_mask: args.BeginMask, | |||||
end_mask: args.EndMask, | |||||
shrink_axis_mask: args.ShrinkAxisMask, | |||||
new_axis_mask: args.NewAxisMask, | |||||
ellipsis_mask: args.EllipsisMask, | |||||
name: name); | |||||
tensor.OriginalVar = this; | |||||
tensor.OriginalVarSlice = args; | |||||
return tensor; | |||||
} | |||||
throw new NotImplementedException(""); | |||||
}); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -40,21 +40,6 @@ namespace Tensorflow | |||||
container: container, | container: container, | ||||
shared_name: shared_name); | shared_name: shared_name); | ||||
public static Tensor assign(Tensor @ref, object value, | |||||
bool validate_shape = true, | |||||
bool use_locking = true, | |||||
string name = null) | |||||
{ | |||||
if (@ref.dtype.is_ref_dtype()) | |||||
return gen_state_ops.assign(@ref, | |||||
value, | |||||
validate_shape: validate_shape, | |||||
use_locking: use_locking, | |||||
name: name); | |||||
return @ref.assign((Tensor)value, name: name); | |||||
} | |||||
public static Tensor assign<T>(T @ref, object value, | public static Tensor assign<T>(T @ref, object value, | ||||
bool validate_shape = true, | bool validate_shape = true, | ||||
bool use_locking = true, | bool use_locking = true, | ||||
@@ -298,45 +298,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
tf.compat.v1.disable_eager_execution(); | tf.compat.v1.disable_eager_execution(); | ||||
} | } | ||||
/// <summary> | |||||
/// Assign tensor to slice of other tensor. | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void TestAssignOfficial() | |||||
{ | |||||
// example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ | |||||
// python | |||||
// import tensorflow as tf | |||||
// A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32) | |||||
// with tf.compat.v1.Session() as sess: | |||||
// sess.run(tf.compat.v1.global_variables_initializer()) | |||||
// print(sess.run(A[:2, :2])) # => [[1,2], [4,5]] | |||||
// op = A[:2,:2].assign(22. * tf.ones((2, 2))) | |||||
// print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]] | |||||
// C# | |||||
// [[1,2,3], [4,5,6], [7,8,9]] | |||||
double[][] initial = new double[][] | |||||
{ | |||||
new double[] { 1, 2, 3 }, | |||||
new double[] { 4, 5, 6 }, | |||||
new double[] { 7, 8, 9 } | |||||
}; | |||||
Tensor A = tf.Variable(initial, dtype: tf.float32); | |||||
// Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]] | |||||
Tensor result1 = A[":2", ":2"]; | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray<double>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray<double>())); | |||||
// An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible' | |||||
Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2))); | |||||
// Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]] | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray<double>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray<double>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray<double>())); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp; | |||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -47,6 +48,41 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
Assert.AreEqual(11f, (float)v1.numpy()); | Assert.AreEqual(11f, (float)v1.numpy()); | ||||
} | } | ||||
/// <summary> | |||||
/// Assign tensor to slice of other tensor. | |||||
/// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void SliceAssign() | |||||
{ | |||||
NDArray nd = new float[,] | |||||
{ | |||||
{ 1, 2, 3 }, | |||||
{ 4, 5, 6 }, | |||||
{ 7, 8, 9 } | |||||
}; | |||||
var x = tf.Variable(nd); | |||||
// get slice form variable | |||||
var sliced = x[":2", ":2"]; | |||||
Assert.AreEqual(nd[0][":2"], sliced[0].numpy()); | |||||
Assert.AreEqual(nd[1][":2"], sliced[1].numpy()); | |||||
// assign to the sliced tensor | |||||
sliced.assign(22 * tf.ones((2, 2))); | |||||
// test assigned value | |||||
nd = new float[,] | |||||
{ | |||||
{ 22, 22, 3 }, | |||||
{ 22, 22, 6 }, | |||||
{ 7, 8, 9 } | |||||
}; | |||||
Assert.AreEqual(nd[0], x[0].numpy()); | |||||
Assert.AreEqual(nd[1], x[1].numpy()); | |||||
Assert.AreEqual(nd[2], x[2].numpy()); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void Accumulation() | public void Accumulation() | ||||
{ | { | ||||