diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
index e4e01c77..06f51352 100644
--- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
@@ -46,7 +46,8 @@ namespace Tensorflow.Train
             var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
             var m = get_slot(var, "m");
             var m_scaled_g_values = grad * (1 - beta1_t);
-            var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
+            var mul = m * beta1_t;
+            var m_t = state_ops.assign(m, mul, use_locking: _use_locking);
             with(ops.control_dependencies(new[] { m_t }), delegate
             {
                 m_t = scatter_add(m, indices, m_scaled_g_values);
@@ -88,9 +89,15 @@ namespace Tensorflow.Train
 
         public override void _prepare()
         {
-            //copied from GradientDescentOptimizer
-            LearningRate = _call_if_callable(LearningRate);
-            LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate");
+            var lr = _call_if_callable(_lr);
+            var beta1 = _call_if_callable(_beta1);
+            var beta2 = _call_if_callable(_beta2);
+            var epsilon = _call_if_callable(_epsilon);
+
+            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            _beta1_t = ops.convert_to_tensor(beta1, name: "beta1");
+            _beta2_t = ops.convert_to_tensor(beta2, name: "beta2");
+            _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon");
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
index f545d859..d69228d6 100644
--- a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
+++ b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
@@ -26,14 +26,13 @@ namespace Tensorflow.Train
         public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
             : base(learning_rate, use_locking, name)
         {
-            LearningRate = learning_rate;
-            LearningRateTensor = null;
+            _lr = learning_rate;
         }
 
         public override void _prepare()
         {
-            LearningRate = _call_if_callable(LearningRate);
-            LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate");
+            var lr = _call_if_callable(_lr);
+            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index d5c1f5c1..c7a31b9d 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -23,8 +23,10 @@ namespace Tensorflow
         string _name;
         public string Name => _name;
 
-        public float LearningRate { get; set; }
-        public Tensor LearningRateTensor { get; set; }
+        protected float _lr;
+        public float LearningRate => _lr;
+        protected Tensor _lr_t;
+        public Tensor LearningRateTensor => _lr_t;
         public bool _use_locking;
         public Dictionary<string, Dictionary<string, RefVariable>> _slots;
         public Dictionary<string, RefVariable> _non_slot_dict;
@@ -38,7 +40,7 @@ namespace Tensorflow
 
             _name = name;
             _use_locking = use_locking;
-            LearningRate = learning_rate;
+            _lr = learning_rate;
             // Dictionary of slots.
             _slots = new Dictionary<string, Dictionary<string, RefVariable>>();
             _non_slot_dict = new Dictionary<string, RefVariable>();
@@ -302,7 +304,7 @@ namespace Tensorflow
 
         protected RefVariable _get_non_slot_variable(string name, Graph graph = null)
         {
-            var key = $"{graph.graph_key}.{name}";
+            var key = $"{name}.{graph.graph_key}";
             var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;
 
             return non_slot;
diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs
index 4022e1dc..22894fe0 100644
--- a/src/TensorFlowNET.Core/Variables/state_ops.cs
+++ b/src/TensorFlowNET.Core/Variables/state_ops.cs
@@ -36,8 +36,8 @@ namespace Tensorflow
                     validate_shape: validate_shape,
                     use_locking: use_locking,
                     name: name);
-            else
-                throw new NotImplementedException("state_ops.assign");
+            throw new NotImplementedException("state_ops.assign");
+            //return @ref.assign(value, name: name);
         }
 
         public static Tensor assign_sub(RefVariable @ref,