diff --git a/test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs b/test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs index 1283ecaf..a37f2892 100644 --- a/test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs +++ b/test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs @@ -15,6 +15,8 @@ namespace TensorFlowNET.UnitTest.Training [TestMethod] public void LinearRegression() { + tf.Graph().as_default(); + // Initialize the weights to `5.0` and the bias to `0.0` // In practice, these should be initialized to random values (for example, with `tf.random.normal`) var W = tf.Variable(5.0f); diff --git a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs index 98738528..1632f1e7 100644 --- a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs +++ b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs @@ -1,8 +1,5 @@ -using Microsoft.VisualStudio.TestPlatform.Utilities; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using System.Diagnostics; -using System.Linq; using Tensorflow.NumPy; using TensorFlowNET.UnitTest; using static Tensorflow.Binding; @@ -27,8 +24,8 @@ namespace Tensorflow.Keras.UnitTest.Optimizers var dtype = GetTypeForNumericType(); // train.GradientDescentOptimizer is V1 only API. - //tf.Graph().as_default(); - /*using (var sess = self.cached_session()) + tf.Graph().as_default(); + using (var sess = self.cached_session()) { var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); @@ -59,7 +56,7 @@ namespace Tensorflow.Keras.UnitTest.Optimizers new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, self.evaluate(var1)); // TODO: self.assertEqual(0, len(optimizer.variables())); - }*/ + } } [TestMethod] @@ -67,7 +64,7 @@ namespace Tensorflow.Keras.UnitTest.Optimizers { //TODO: add np.half TestBasic(); - // TestBasic(); + TestBasic(); } private void TestTensorLearningRate() where T : struct @@ -115,8 +112,8 @@ namespace Tensorflow.Keras.UnitTest.Optimizers public void TestTensorLearningRate() { //TODO: add np.half - // TestTensorLearningRate(); - // TestTensorLearningRate(); + TestTensorLearningRate(); + TestTensorLearningRate(); } } }