You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConstantTest.cs 3.4 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. namespace TensorFlowNET.UnitTest
  9. {
  10. [TestClass]
  11. public class ConstantTest
  12. {
  13. Status status = new Status();
  14. [TestMethod]
  15. public void ScalarConst()
  16. {
  17. var tensor1 = tf.constant(8); // int
  18. var tensor2 = tf.constant(6.0f); // float
  19. var tensor3 = tf.constant(6.0); // double
  20. }
  21. [TestMethod]
  22. public void StringConst()
  23. {
  24. string str = "Hello, TensorFlow.NET!";
  25. var tensor = tf.constant(str);
  26. Python.with<Session>(tf.Session(), sess =>
  27. {
  28. var result = sess.run(tensor);
  29. Assert.IsTrue(result.Data<string>()[0] == str);
  30. });
  31. }
  32. [TestMethod]
  33. public void ZerosConst()
  34. {
  35. // small size
  36. var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small");
  37. Python.with<Session>(tf.Session(), sess =>
  38. {
  39. var result = sess.run(tensor);
  40. Assert.AreEqual(result.shape[0], 3);
  41. Assert.AreEqual(result.shape[1], 2);
  42. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>()));
  43. });
  44. // big size
  45. tensor = tf.zeros(new Shape(200, 100), TF_DataType.TF_INT32, "big");
  46. Python.with<Session>(tf.Session(), sess =>
  47. {
  48. var result = sess.run(tensor);
  49. Assert.AreEqual(result.shape[0], 200);
  50. Assert.AreEqual(result.shape[1], 100);
  51. var data = result.Data<int>();
  52. Assert.AreEqual(0, data[0]);
  53. Assert.AreEqual(0, data[result.size - 1]);
  54. });
  55. }
  56. [TestMethod]
  57. public void NDimConst()
  58. {
  59. var nd = np.array(new int[][]
  60. {
  61. new int[]{ 3, 1, 1 },
  62. new int[]{ 2, 1, 3 }
  63. });
  64. var tensor = tf.constant(nd);
  65. Python.with<Session>(tf.Session(), sess =>
  66. {
  67. var result = sess.run(tensor);
  68. var data = result.Data<int>();
  69. Assert.AreEqual(result.shape[0], 2);
  70. Assert.AreEqual(result.shape[1], 3);
  71. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data));
  72. });
  73. }
  74. [TestMethod]
  75. public void Multiply()
  76. {
  77. var a = tf.constant(3.0);
  78. var b = tf.constant(2.0);
  79. var c = a * b;
  80. var sess = tf.Session();
  81. double result = sess.run(c);
  82. sess.close();
  83. Assert.AreEqual(6.0, result);
  84. }
  85. [TestMethod]
  86. public void StringEncode()
  87. {
  88. string str = "Hello, TensorFlow.NET!";
  89. ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
  90. Assert.AreEqual(dst_len, (ulong)23);
  91. string dst = "";
  92. c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status);
  93. Assert.AreEqual(status.Code, TF_Code.TF_OK);
  94. //c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
  95. }
  96. }
  97. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。