Browse Source

learning rate test

pull/1184/head
Alexander 1 year ago
parent
commit
c906f46aad
1 changed files with 49 additions and 0 deletions
  1. +49
    -0
      test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

+ 49
- 0
test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.NumPy;
using TensorFlowNET.UnitTest;
@@ -69,5 +70,53 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
TestBasic<double>();
}

private void TestTensorLearningRate<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();

// train.GradientDescentOptimizer is V1 only API.
tf.Graph().as_default();
using (var sess = self.cached_session())
{
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
var lrate = constant_op.constant(3.0);
var grads_and_vars = new[] {
Tuple.Create(grads0, var0 as IVariableV1),
Tuple.Create(grads1, var1 as IVariableV1)
};
var sgd_op = tf.train.GradientDescentOptimizer(lrate)
.apply_gradients(grads_and_vars);

var global_variables = tf.global_variables_initializer();
sess.run(global_variables);

var initialVar0 = sess.run(var0);
var initialVar1 = sess.run(var1);
// Fetch params to validate initial values
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
// Run 1 step of sgd
sgd_op.run();
// Validate updated params
self.assertAllCloseAccordingToType(
new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(
new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
self.evaluate<T[]>(var1));
// TODO: self.assertEqual(0, len(optimizer.variables()));
}
}

[TestMethod]
public void TestTensorLearningRate()
{
//TODO: add np.half
TestTensorLearningRate<float>();
TestTensorLearningRate<double>();
}
}
}

Loading…
Cancel
Save