You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

ConstantTest.cs 2.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. namespace TensorFlowNET.UnitTest
  9. {
  10. [TestClass]
  11. public class ConstantTest
  12. {
  13. [TestMethod]
  14. public void ScalarConst()
  15. {
  16. var tensor1 = tf.constant(8); // int
  17. var tensor2 = tf.constant(6.0f); // float
  18. var tensor3 = tf.constant(6.0); // double
  19. }
  20. [TestMethod]
  21. public void StringConst()
  22. {
  23. string str = "Hello, TensorFlow.NET!";
  24. var tensor = tf.constant(str);
  25. Python.with<Session>(tf.Session(), sess =>
  26. {
  27. var result = sess.run(tensor);
  28. Assert.IsTrue(result.Data<string>()[0] == str);
  29. });
  30. }
  31. [TestMethod]
  32. public void ZerosConst()
  33. {
  34. // small size
  35. var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small");
  36. Python.with<Session>(tf.Session(), sess =>
  37. {
  38. var result = sess.run(tensor);
  39. Assert.AreEqual(result.shape[0], 3);
  40. Assert.AreEqual(result.shape[1], 2);
  41. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>()));
  42. });
  43. // big size
  44. tensor = tf.zeros(new Shape(200, 100), TF_DataType.TF_INT32, "big");
  45. Python.with<Session>(tf.Session(), sess =>
  46. {
  47. var result = sess.run(tensor);
  48. Assert.AreEqual(result.shape[0], 200);
  49. Assert.AreEqual(result.shape[1], 100);
  50. var data = result.Data<int>();
  51. Assert.AreEqual(0, data[0]);
  52. Assert.AreEqual(0, data[result.size - 1]);
  53. });
  54. }
  55. [TestMethod]
  56. public void NDimConst()
  57. {
  58. var nd = np.array(new int[][]
  59. {
  60. new int[]{ 3, 1, 1 },
  61. new int[]{ 2, 1, 3 }
  62. });
  63. var tensor = tf.constant(nd);
  64. Python.with<Session>(tf.Session(), sess =>
  65. {
  66. var result = sess.run(tensor);
  67. var data = result.Data<int>();
  68. Assert.AreEqual(result.shape[0], 2);
  69. Assert.AreEqual(result.shape[1], 3);
  70. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data));
  71. });
  72. }
  73. [TestMethod]
  74. public void Multiply()
  75. {
  76. var a = tf.constant(3.0);
  77. var b = tf.constant(2.0);
  78. var c = a * b;
  79. var sess = tf.Session();
  80. double result = sess.run(c);
  81. sess.close();
  82. Assert.AreEqual(6.0, result);
  83. }
  84. }
  85. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。