You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LinalgTest.cs 3.7 kB

4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for the managed <c>tf.linalg</c> API: eye, lstsq, einsum, global_norm,
      /// tensordot and matmul.
      /// </summary>
      /// <remarks>
      /// MSTest's <c>Assert.AreEqual</c> takes the expected value first and the actual
      /// value second; that order is used throughout so failure messages read correctly.
      /// </remarks>
      [TestClass]
      public class LinalgTest : EagerModeTestBase
      {
          /// <summary>
          /// <c>tf.linalg.eye(3)</c> must produce a 3x3 identity matrix; the last row is
          /// spot-checked (zeros off the diagonal, one on it).
          /// </summary>
          [TestMethod]
          public void EyeTest()
          {
              var tensor = tf.linalg.eye(3);

              Assert.AreEqual((3, 3), tensor.shape);
              // Elements are read back as double, so compare against double literals.
              Assert.AreEqual(0.0, (double)tensor[2, 0]);
              Assert.AreEqual(0.0, (double)tensor[2, 1]);
              Assert.AreEqual(1.0, (double)tensor[2, 2]);
          }

          /// <summary>
          /// Checks <c>tf.linalg.lstsq</c> against closed-form least-squares solutions:
          /// the normal-equations solution <c>(A^T A)^-1 A^T b</c> for an over-determined
          /// system, and the minimum-norm solution <c>A^T (A A^T)^-1 b</c> for an
          /// under-determined one.
          /// </summary>
          /// <remarks>
          /// Example values follow
          /// https://colab.research.google.com/github/biswajitsahoo1111/blog_notebooks/blob/master/Doing_Linear_Algebra_using_Tensorflow_2.ipynb#scrollTo=6xfOcTFBL3Up
          /// </remarks>
          [TestMethod]
          public void LSTSQ()
          {
              // Over-determined: 4 equations, 2 unknowns.
              var A_over = tf.constant(new float[,] { { 1, 2 }, { 2, 0.5f }, { 3, 1 }, { 4, 5.0f } });
              // Under-determined: 2 equations, 4 unknowns.
              var A_under = tf.constant(new float[,] { { 3, 1, 2, 5 }, { 7, 9, 1, 4.0f } });
              var b_over = tf.constant(new float[] { 3, 4, 5, 6.0f }, shape: (4, 1));
              var b_under = tf.constant(new float[] { 7.2f, -5.8f }, shape: (2, 1));

              var x_over = tf.linalg.lstsq(A_over, b_over);
              // Reference: x = (A^T A)^-1 (A^T b).
              var x = tf.matmul(tf.linalg.inv(tf.matmul(A_over, A_over, transpose_a: true)), tf.matmul(A_over, b_over, transpose_a: true));
              Assert.AreEqual((2, 1), x_over.shape);
              AssetSequenceEqual(x_over.ToArray<float>(), x.ToArray<float>());

              var x_under = tf.linalg.lstsq(A_under, b_under);
              // Reference: y = A^T ((A A^T)^-1 b).
              var y = tf.matmul(A_under, tf.matmul(tf.linalg.inv(tf.matmul(A_under, A_under, transpose_b: true)), b_under), transpose_a: true);
              Assert.AreEqual((4, 1), x_under.shape);
              AssetSequenceEqual(x_under.ToArray<float>(), y.ToArray<float>());

              // NOTE(review): the l2_regularizer checks below are disabled and the reason is not
              // recorded here (presumably the regularized lstsq path is unsupported or does not
              // yet match these expected values) - confirm before re-enabling.
              /*var x_over_reg = tf.linalg.lstsq(A_over, b_over, l2_regularizer: 2.0f);
              var x_under_reg = tf.linalg.lstsq(A_under, b_under, l2_regularizer: 2.0f);
              Assert.AreEqual(x_under_reg.shape, (4, 1));
              AssetSequenceEqual(x_under_reg.ToArray<float>(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });*/
          }

          /// <summary>
          /// <c>einsum("ij,jk->ik")</c> is a matrix product, so a (2x3) by (3x5) product
          /// must be (2x5). Inputs are random, so only the output shape is asserted.
          /// </summary>
          [TestMethod]
          public void Einsum()
          {
              var m0 = tf.random.normal((2, 3));
              var m1 = tf.random.normal((3, 5));

              var e = tf.linalg.einsum("ij,jk->ik", (m0, m1));

              Assert.AreEqual((2, 5), e.shape);
          }

          /// <summary>
          /// The global norm of [1, 2, 3, 4] and [5, 6, 7, 8] is sqrt(1^2 + ... + 8^2) = sqrt(204),
          /// approximately 14.282857.
          /// </summary>
          [TestMethod]
          public void GlobalNorm()
          {
              var t_list = new Tensors(tf.constant(new float[] { 1, 2, 3, 4 }), tf.constant(new float[] { 5, 6, 7, 8 }));

              var norm = tf.linalg.global_norm(t_list);

              // Floating-point result: compare with a tolerance instead of exact equality.
              Assert.AreEqual(14.282857f, norm.ToArray<float>()[0], 1e-5f);
          }

          /// <summary>
          /// <c>tensordot</c> with <c>axes = 0</c> is the outer product ([1,2] x [2,3] = [[2,3],[4,6]]);
          /// contracting axis 0 of both vectors is the dot product (1*2 + 2*3 = 8), a scalar.
          /// </summary>
          [TestMethod]
          public void Tensordot()
          {
              var a = tf.constant(new[] { 1, 2 });
              var b = tf.constant(new[] { 2, 3 });

              var c = tf.linalg.tensordot(a, b, 0);
              Assert.AreEqual((2, 2), c.shape);
              AssetSequenceEqual(c.ToArray<int>(), new[] { 2, 3, 4, 6 });

              c = tf.linalg.tensordot(a, b, new[] { 0, 0 });
              Assert.AreEqual(0, c.shape.ndim);
              Assert.AreEqual(8, c.ToArray<int>()[0]);
          }

          /// <summary>
          /// (2x3) by (3x2) integer matrix product; the expected values are the hand-computed
          /// row-by-column sums (e.g. 1*7 + 2*9 + 3*11 = 58).
          /// </summary>
          [TestMethod]
          public void Matmul()
          {
              var a = tf.constant(new[] { 1, 2, 3, 4, 5, 6 }, shape: (2, 3));
              var b = tf.constant(new[] { 7, 8, 9, 10, 11, 12 }, shape: (3, 2));

              var c = tf.linalg.matmul(a, b);

              Assert.AreEqual((2, 2), c.shape);
              AssetSequenceEqual(c.ToArray<int>(), new[] { 58, 64, 139, 154 });
          }
      }
  }