Browse Source

add RefVariable override of state_ops.assign #271

tags/v0.9
Oceania2018 6 years ago
parent
commit
2a17b9c359
3 changed files with 33 additions and 2 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Train/AdamOptimizer.cs
  2. +20
    -0
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  3. +12
    -0
      src/TensorFlowNET.Core/Variables/state_ops.cs

+ 1
- 2
src/TensorFlowNET.Core/Train/AdamOptimizer.cs View File

@@ -46,8 +46,7 @@ namespace Tensorflow.Train
var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
var m = get_slot(var, "m");
var m_scaled_g_values = grad * (1 - beta1_t);
var mul = m * beta1_t;
var m_t = state_ops.assign(m, mul, use_locking: _use_locking);
var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
with(ops.control_dependencies(new[] { m_t }), delegate
{
m_t = scatter_add(m, indices, m_scaled_g_values);


+ 20
- 0
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -67,6 +67,26 @@ namespace Tensorflow
return _result[0];
}

public static Tensor assign(RefVariable @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });

var _result = _op.outputs;
var _inputs_flat = _op.inputs;

var _attrs = new Dictionary<string, object>();
_attrs["T"] = _op.get_attr("T");
_attrs["validate_shape"] = _op.get_attr("validate_shape");
_attrs["use_locking"] = _op.get_attr("use_locking");

_execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name);

return _result[0];
}

public static Tensor assign_sub(RefVariable @ref,
Tensor value,
bool use_locking = false,


+ 12
- 0
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -40,6 +40,18 @@ namespace Tensorflow
//return @ref.assign(value, name: name);
}

public static Tensor assign(RefVariable @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
return gen_state_ops.assign(@ref,
value,
validate_shape: validate_shape,
use_locking: use_locking,
name: name);
}

public static Tensor assign_sub(RefVariable @ref,
Tensor value,
bool use_locking = false,


Loading…
Cancel
Save