You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LinalgTest.cs 2.2 kB

4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Eager-mode tests for the <c>tf.linalg</c> managed API.
      /// </summary>
      [TestClass]
      public class LinalgTest : EagerModeTestBase
      {
          /// <summary>
          /// <c>tf.linalg.eye(3)</c> must yield a 3x3 identity matrix: ones on the
          /// main diagonal and zeros everywhere else.
          /// </summary>
          [TestMethod]
          public void EyeTest()
          {
              const int size = 3;

              var tensor = tf.linalg.eye(size);

              // MSTest's contract is Assert.AreEqual(expected, actual); keeping that order
              // makes the "Expected"/"Actual" labels in failure messages correct.
              Assert.AreEqual((size, size), tensor.shape);

              // Verify every element rather than a single row. Elements are read through an
              // explicit (double) conversion; exact comparison is safe because 0 and 1 are
              // exactly representable.
              for (var row = 0; row < size; row++)
              {
                  for (var col = 0; col < size; col++)
                  {
                      var expected = row == col ? 1.0 : 0.0;
                      Assert.AreEqual(expected, (double)tensor[row, col], $"eye({size})[{row}, {col}]");
                  }
              }
          }

          /// <summary>
          /// Checks <c>tf.linalg.lstsq</c> against closed-form least-squares solutions for an
          /// over-determined and an under-determined system, following
          /// https://colab.research.google.com/github/biswajitsahoo1111/blog_notebooks/blob/master/Doing_Linear_Algebra_using_Tensorflow_2.ipynb#scrollTo=6xfOcTFBL3Up
          /// </summary>
          [TestMethod]
          public void LSTSQ()
          {
              // Over-determined system: 4 equations, 2 unknowns.
              var aOver = tf.constant(new float[,] { { 1, 2 }, { 2, 0.5f }, { 3, 1 }, { 4, 5.0f } });
              var bOver = tf.constant(new float[] { 3, 4, 5, 6.0f }, shape: (4, 1));

              var xOver = tf.linalg.lstsq(aOver, bOver);

              // Reference via the normal equations: x = (A^T A)^-1 A^T b.
              var xOverExpected = tf.matmul(
                  tf.linalg.inv(tf.matmul(aOver, aOver, transpose_a: true)),
                  tf.matmul(aOver, bOver, transpose_a: true));
              Assert.AreEqual((2, 1), xOver.shape);
              AssetSequenceEqual(xOver.ToArray<float>(), xOverExpected.ToArray<float>());

              // Under-determined system: 2 equations, 4 unknowns.
              var aUnder = tf.constant(new float[,] { { 3, 1, 2, 5 }, { 7, 9, 1, 4.0f } });
              var bUnder = tf.constant(new float[] { 7.2f, -5.8f }, shape: (2, 1));

              var xUnder = tf.linalg.lstsq(aUnder, bUnder);

              // Reference minimum-norm solution: x = A^T (A A^T)^-1 b.
              var xUnderExpected = tf.matmul(
                  aUnder,
                  tf.matmul(tf.linalg.inv(tf.matmul(aUnder, aUnder, transpose_b: true)), bUnder),
                  transpose_a: true);
              Assert.AreEqual((4, 1), xUnder.shape);
              AssetSequenceEqual(xUnder.ToArray<float>(), xUnderExpected.ToArray<float>());

              // NOTE(review): the l2_regularizer checks below are disabled and the reason is not
              // recorded here. Confirm whether tf.linalg.lstsq supports l2_regularizer in this
              // binding, then either re-enable them as a separate test or delete them.
              /*var xOverReg = tf.linalg.lstsq(aOver, bOver, l2_regularizer: 2.0f);
              var xUnderReg = tf.linalg.lstsq(aUnder, bUnder, l2_regularizer: 2.0f);
              Assert.AreEqual((4, 1), xUnderReg.shape);
              AssetSequenceEqual(xUnderReg.ToArray<float>(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });*/
          }
      }
  }