
1: change learning rate and lr_tensor.
2: override _prepare() for AdamOptimizer.
3: fix key name in _get_non_slot_variable.
tags/v0.9
Oceania2018 committed 6 years ago
commit 59672a4df7
4 changed files with 22 additions and 14 deletions
  1. +11 -4   src/TensorFlowNET.Core/Train/AdamOptimizer.cs
  2. +3  -4   src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
  3. +6  -4   src/TensorFlowNET.Core/Train/Optimizer.cs
  4. +2  -2   src/TensorFlowNET.Core/Variables/state_ops.cs

src/TensorFlowNET.Core/Train/AdamOptimizer.cs  (+11 -4)

@@ -46,7 +46,8 @@ namespace Tensorflow.Train
             var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
             var m = get_slot(var, "m");
             var m_scaled_g_values = grad * (1 - beta1_t);
-            var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
+            var mul = m * beta1_t;
+            var m_t = state_ops.assign(m, mul, use_locking: _use_locking);
             with(ops.control_dependencies(new[] { m_t }), delegate
             {
                 m_t = scatter_add(m, indices, m_scaled_g_values);
@@ -88,9 +89,15 @@ namespace Tensorflow.Train
 
         public override void _prepare()
         {
-            //copied from GradientDescentOptimizer
-            LearningRate = _call_if_callable(LearningRate);
-            LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate");
+            var lr = _call_if_callable(_lr);
+            var beta1 = _call_if_callable(_beta1);
+            var beta2 = _call_if_callable(_beta2);
+            var epsilon = _call_if_callable(_epsilon);
+
+            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            _beta1_t = ops.convert_to_tensor(beta1, name: "beta1");
+            _beta2_t = ops.convert_to_tensor(beta2, name: "beta2");
+            _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon");
         }
     }
 }
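For reference, the tensors built by the new _prepare() feed the standard Adam update, whose bias-corrected step size is exactly what the first hunk above computes. A plain-scalar C# sketch of the same math (illustrative only, not the library's tensor code; AdamSketch and AdamStep are made-up names):

using System;

static class AdamSketch
{
    // One Adam step for a single scalar parameter. Mirrors the diff:
    //   lr_t  = lr * sqrt(1 - beta2^t) / (1 - beta1^t)   (hunk at line 46)
    //   m     = beta1 * m + (1 - beta1) * grad           (first moment)
    //   v     = beta2 * v + (1 - beta2) * grad^2         (second moment)
    //   theta = theta - lr_t * m / (sqrt(v) + epsilon)
    public static (float m, float v, float theta) AdamStep(
        float theta, float grad, float m, float v, int t,
        float lr = 0.001f, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f)
    {
        m = beta1 * m + (1 - beta1) * grad;
        v = beta2 * v + (1 - beta2) * grad * grad;
        var lr_t = lr * (float)Math.Sqrt(1 - Math.Pow(beta2, t))
                      / (1 - (float)Math.Pow(beta1, t));
        theta -= lr_t * m / ((float)Math.Sqrt(v) + epsilon);
        return (m, v, theta);
    }
}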

src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs  (+3 -4)

@@ -26,14 +26,13 @@ namespace Tensorflow.Train
         public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
             : base(learning_rate, use_locking, name)
         {
-            LearningRate = learning_rate;
-            LearningRateTensor = null;
+            _lr = learning_rate;
         }
 
         public override void _prepare()
         {
-            LearningRate = _call_if_callable(LearningRate);
-            LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate");
+            var lr = _call_if_callable(_lr);
+            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
         }
     }
 }
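After this change _lr is the single source of truth: the constructor stores the raw float and _prepare() materializes it as the "learning_rate" tensor when the training op is built. A minimal usage sketch (the tf.train.GradientDescentOptimizer entry point and minimize() are assumed from the library's public API; loss stands for any scalar tensor defined elsewhere):

var optimizer = tf.train.GradientDescentOptimizer(0.01f);  // constructor stores _lr = 0.01f
var train_op = optimizer.minimize(loss);                   // _prepare() runs while building this op, creating _lr_t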

src/TensorFlowNET.Core/Train/Optimizer.cs  (+6 -4)

@@ -23,8 +23,10 @@ namespace Tensorflow
 
         string _name;
         public string Name => _name;
-        public float LearningRate { get; set; }
-        public Tensor LearningRateTensor { get; set; }
+        protected float _lr;
+        public float LearningRate => _lr;
+        protected Tensor _lr_t;
+        public Tensor LearningRateTensor => _lr_t;
         public bool _use_locking;
         public Dictionary<string, Dictionary<string, RefVariable>> _slots;
         public Dictionary<string, RefVariable> _non_slot_dict;
@@ -38,7 +40,7 @@ namespace Tensorflow
 
             _name = name;
             _use_locking = use_locking;
-            LearningRate = learning_rate;
+            _lr = learning_rate;
             // Dictionary of slots.
             _slots = new Dictionary<string, Dictionary<string, RefVariable>>();
             _non_slot_dict = new Dictionary<string, RefVariable>();
@@ -302,7 +304,7 @@ namespace Tensorflow
 
         protected RefVariable _get_non_slot_variable(string name, Graph graph = null)
        {
-            var key = $"{graph.graph_key}.{name}";
+            var key = $"{name}.{graph.graph_key}";
             var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;
 
             return non_slot;
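The key fix in the hunk above only helps because the lookup now builds the same string that is used when the non-slot variable is stored; with the old order the dictionary lookup could never hit. A minimal sketch of the mismatch (the graph_key value is made up for illustration, and "{name}.{graph_key}" is assumed to be the layout used at creation time):

var name = "beta1_power";               // Adam's non-slot accumulator, per the AdamOptimizer hunks
var graph_key = "graph-key-0/";         // illustrative value only
var stored = $"{name}.{graph_key}";     // key layout assumed at creation time
var oldKey = $"{graph_key}.{name}";     // what _get_non_slot_variable built before: never equals stored
var newKey = $"{name}.{graph_key}";     // after this commit: equals stored, so the lookup succeeds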


src/TensorFlowNET.Core/Variables/state_ops.cs  (+2 -2)

@@ -36,8 +36,8 @@ namespace Tensorflow
                     validate_shape: validate_shape,
                     use_locking: use_locking,
                     name: name);
-            else
-                throw new NotImplementedException("state_ops.assign");
+
+            throw new NotImplementedException("state_ops.assign");
             //return @ref.assign(value, name: name);
         }
 
         public static Tensor assign_sub(RefVariable @ref,
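Dropping the else here is safe if the branch above it ends in a return, which the surrounding context suggests (the gen_state_ops.assign result is returned); the throw then runs only on the fall-through path. A generic sketch of the guard-clause shape (illustrative, not the file's exact code; names are made up):

using System;

static class GuardClauseSketch
{
    static int RequireNonNegative(int x)
    {
        if (x >= 0)
            return x;                        // the taken branch exits the method
        throw new NotImplementedException(); // fall-through path needs no else
    }
}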

