From 66f390a910619c730778917319d4f54316123e00 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 8 Feb 2019 07:00:55 -0600 Subject: [PATCH] add operator overload for RefVariable. --- .../IReturnTensorOrOperation.cs | 14 ++++++++++ .../Operations/Operation.cs | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 4 +-- .../Variables/RefVariable.Operators.cs | 18 +++++++++--- .../Variables/RefVariable.cs | 28 +++++++++++++++++++ test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 4 +++ 6 files changed, 63 insertions(+), 7 deletions(-) create mode 100644 src/TensorFlowNET.Core/IReturnTensorOrOperation.cs diff --git a/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs b/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs new file mode 100644 index 00000000..51c840ac --- /dev/null +++ b/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + /// + /// in order to limit function return value + /// is Tensor or Operation + /// + public interface IReturnTensorOrOperation + { + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 769c7fb3..4d268c6d 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -7,7 +7,7 @@ using System.Text; namespace Tensorflow { - public partial class Operation + public partial class Operation : IReturnTensorOrOperation { private readonly IntPtr _handle; // _c_op in python diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 90b7c2dc..ab7d3304 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// - public partial class Tensor : IDisposable + public partial class Tensor : IDisposable, IReturnTensorOrOperation { private readonly IntPtr _handle; @@ -258,7 +258,7 @@ namespace Tensorflow } } - return $"{name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; + return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; } public void Dispose() diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs index 85fb19f8..dc510fcc 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -6,11 +6,21 @@ namespace Tensorflow { public partial class RefVariable { - public static Tensor operator +(RefVariable t1, int t2) + public static Tensor operator +(RefVariable x, int y) => op_helper("add", x, y); + public static Tensor operator +(RefVariable x, float y) => op_helper("add", x, y); + public static Tensor operator +(RefVariable x, double y) => op_helper("add", x, y); + + public static Tensor operator -(RefVariable x, int y) => op_helper("sub", x, y); + public static Tensor operator -(RefVariable x, float y) => op_helper("sub", x, y); + public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y); + + private static Tensor op_helper(string default_name, RefVariable x, T y) { - var tensor1 = t1._AsTensor(); - var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y"); - return gen_math_ops.add(tensor1, tensor2); + var tensor1 = x.value(); + return Python.with(new ops.name_scope("", default_name, new object[] { tensor1, y }), scope => { + var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y"); + return gen_math_ops.add(tensor1, tensor2, scope); + }); } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 52d5f056..39a8d909 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -171,6 +171,34 @@ namespace Tensorflow return op; } + /// + /// Assigns a new value to the variable. + /// + /// The new value for this variable. + /// If `True`, use locking during the assignment. + /// The name of the operation to be created + /// + /// if True, will return something which evaluates to the + /// new value of the variable; if False will return the assign op. + /// + /// + /// A `Tensor` that will hold the new value of this variable after + /// the assignment has completed. + /// + public T assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true) + where T : IReturnTensorOrOperation + { + var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); + if (read_value) + return (T)Convert.ChangeType(assign, typeof(T)); + return (T)Convert.ChangeType(assign.op, typeof(T)); + } + + public Tensor assign(Tensor value, bool use_locking = false, string name = "") + { + return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); + } + public override string ToString() { return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 1bb5fc3d..99e7ee20 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -15,7 +15,11 @@ namespace TensorFlowNET.UnitTest var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); + var inc_v1 = v1.assign(v1 + 1.0f); + var dec_v2 = v2.assign(v2 - 1.0f); + // Add an op to initialize the variables. + var init_op = tf.global_variables_initializer(); } } }