Browse Source

Merge pull request #1161 from Wanglongzhi2001/master

fix: add the momentum parameter's implementation of SGD
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
48403a55e2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 4 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
  2. +4
    -0
      src/TensorFlowNET.Core/Training/gen_training_ops.cs
  3. +2
    -2
      src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
  4. +18
    -1
      src/TensorFlowNET.Keras/Optimizers/SGD.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/IOptimizerApi.cs View File

@@ -63,6 +63,6 @@ namespace Tensorflow.Keras
bool centered = false, bool centered = false,
string name = "RMSprop"); string name = "RMSprop");


IOptimizer SGD(float learning_rate);
/// <summary>
/// Creates a stochastic gradient descent optimizer.
/// </summary>
/// <param name="learning_rate">Step size applied to parameter updates.</param>
/// <param name="momentum">
/// Momentum factor; defaults to 0 (plain gradient descent) so existing callers
/// of the previous single-argument overload continue to compile unchanged.
/// </param>
/// <returns>A configured SGD optimizer.</returns>
IOptimizer SGD(float learning_rate, float momentum = 0f);
} }
} }

+ 4
- 0
src/TensorFlowNET.Core/Training/gen_training_ops.cs View File

@@ -51,5 +51,9 @@ namespace Tensorflow
public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
=> tf.Context.ExecuteOp("ResourceApplyGradientDescent", name, => tf.Context.ExecuteOp("ResourceApplyGradientDescent", name,
new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking })); new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking }));

/// <summary>
/// Dispatches the "ResourceApplyKerasMomentum" op, which updates <paramref name="var"/>
/// and its momentum accumulator <paramref name="accum"/> in place.
/// </summary>
/// <param name="var">Resource handle of the variable to update.</param>
/// <param name="accum">Resource handle of the momentum accumulator slot.</param>
/// <param name="lr">Learning-rate scalar tensor.</param>
/// <param name="grad">Gradient tensor for <paramref name="var"/>.</param>
/// <param name="momentum">Momentum scalar tensor.</param>
/// <param name="use_locking">Whether the underlying op should lock the variable during the update.</param>
/// <param name="use_nesterov">Whether the op applies the Nesterov variant of momentum.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor produced by the executed op.</returns>
public static Tensor resource_apply_keras_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor momentum, bool use_locking = false, bool use_nesterov = false, string name = null)
{
    var opArgs = new ExecuteOpArgs(var, accum, lr, grad, momentum)
        .SetAttributes(new { use_locking, use_nesterov });
    return tf.Context.ExecuteOp("ResourceApplyKerasMomentum", name, opArgs);
}
} }
} }

+ 2
- 2
src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Optimizers
Name = name Name = name
}); });


public IOptimizer SGD(float learning_rate)
=> new SGD(learning_rate);
/// <summary>
/// Creates a stochastic gradient descent optimizer, optionally with momentum.
/// </summary>
/// <param name="learning_rate">Step size applied to parameter updates.</param>
/// <param name="momentum">
/// Momentum factor; defaults to 0 (plain SGD) so callers of the previous
/// single-argument overload keep working. The SGD constructor rejects values
/// outside [0, 1].
/// </param>
/// <returns>A configured <see cref="SGD"/> optimizer.</returns>
public IOptimizer SGD(float learning_rate, float momentum = 0f)
    => new SGD(learning_rate, momentum);
} }
} }

+ 18
- 1
src/TensorFlowNET.Keras/Optimizers/SGD.cs View File

@@ -22,6 +22,8 @@ namespace Tensorflow.Keras.Optimizers
_set_hyper("decay", decay); _set_hyper("decay", decay);


_momentum = momentum > 0; _momentum = momentum > 0;
if (momentum < 0 || momentum > 1)
throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}.");


_set_hyper("momentum", momentum); _set_hyper("momentum", momentum);


@@ -30,6 +32,13 @@ namespace Tensorflow.Keras.Optimizers
#pragma warning restore CS1717 // Assignment made to same variable #pragma warning restore CS1717 // Assignment made to same variable
} }


/// <summary>
/// Creates a "momentum" accumulator slot for each variable in
/// <paramref name="var_list"/>. When momentum is disabled no slots are needed.
/// </summary>
/// <param name="var_list">Variables the optimizer will update.</param>
protected override void _create_slots(IVariableV1[] var_list)
{
    if (!_momentum)
        return;

    foreach (var variable in var_list)
    {
        add_slot(variable, "momentum");
    }
}

protected override void _prepare_local(DeviceDType device_dtype, protected override void _prepare_local(DeviceDType device_dtype,
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
{ {
@@ -43,7 +52,15 @@ namespace Tensorflow.Keras.Optimizers
{ {
if (_momentum) if (_momentum)
{ {
throw new NotImplementedException("_resource_apply_dense");
var momentum_var = get_slot(var, "momentum");
return gen_training_ops.resource_apply_keras_momentum(
var.Handle,
momentum_var.Handle,
_get_hyper("learning_rate", var.dtype),
grad,
_get_hyper("momentum", var.dtype),
use_locking: _use_locking,
use_nesterov: nesterov);
} }
var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype()); var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype());




Loading…
Cancel
Save