diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index dfc55ace..c30a0be2 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -82,15 +82,22 @@ namespace Tensorflow var value_tensor = ops.convert_to_tensor(value, dtype: dtype); var assign_op = gen_resource_variable_ops.assign_variable_op( handle, value_tensor, name: name); + if (read_value) - { return gen_resource_variable_ops.read_variable_op(handle, dtype); - // var variable = _lazy_read(assign_op, value_tensor); - // return variable; - } + return assign_op; } + public IVariableV1 assign_lazy_load(Tensor value, string name = null) + { + var value_tensor = ops.convert_to_tensor(value, dtype: dtype); + var assign_op = gen_resource_variable_ops.assign_variable_op( + handle, value_tensor, name: name); + var variable = _lazy_read(assign_op, value_tensor); + return variable; + } + public Tensor value() => GraphElement ?? _read_variable_op(); @@ -157,6 +164,25 @@ namespace Tensorflow return assign_add_op; } + public Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true) + { + var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle, + ops.convert_to_tensor(delta, dtype: dtype), name: name); + + if (read_value) + return gen_resource_variable_ops.read_variable_op(handle, dtype); + // return _lazy_read(assign_add_op); + return assign_sub_op; + } + + public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null) + { + var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle, + ops.convert_to_tensor(delta, dtype: dtype), name: name); + + return _lazy_read(assign_sub_op, delta); + } + public override string ToString() { if (tf.Context.executing_eagerly()) diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index cd5afb79..4e1b7a96 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -47,7 +47,10 @@ namespace Tensorflow TF_DataType dtype { get; } TensorShape shape { get; } Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); + Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true); + IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true); + IVariableV1 assign_lazy_load(Tensor value, string name = null); Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false); NDArray numpy(); } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 849f3157..3aad10bc 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -23,6 +23,7 @@ using static Tensorflow.Binding; namespace Tensorflow { + [Obsolete] public partial class RefVariable : IVariableV1, IProtoBuf { protected string _name; @@ -428,5 +429,20 @@ namespace Tensorflow public NDArray numpy() => throw new RuntimeError("Graph mode can't use numpy()."); + + public Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true) + { + throw new NotImplementedException(); + } + + public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null) + { + throw new NotImplementedException(); + } + + public IVariableV1 assign_lazy_load(Tensor value, string name = null) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index cd0902c6..ce587397 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -90,7 +90,7 @@ namespace Tensorflow value, use_locking: use_locking, name: name) : - @ref.assign(value, name: name) as Tensor; + @ref.assign_sub(value, name: name); //"""Update 'ref' by adding 'value' to it. // diff --git a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs index 3b0b9013..a160e496 100644 --- a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs @@ -209,23 +209,23 @@ namespace Tensorflow.Keras.Layers return output; } - Tensor _assign_new_value(IVariableV1 variable, Tensor value) + void _assign_new_value(IVariableV1 variable, Tensor value) { - return tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope => + tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope => { // var cm = ops.colocate_with(variable); - return state_ops.assign_sub(variable, value, name: scope); + variable.assign_lazy_load(value, name: scope); }); } - Tensor _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum) + void _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum) { - return tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope => + tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope => { // var cm = ops.colocate_with(variable); var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay"); var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay; - return state_ops.assign_sub(variable, update_delta, name: scope); + variable.assign_sub_lazy_load(update_delta, name: scope); }); } }