You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

TfLiteTest.cs 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Tensorflow.Lite;
  namespace Tensorflow.Native.UnitTest
  {
      /// <summary>
      /// Tests for the TensorFlow Lite C API bindings exposed through <c>c_api_lite</c>.
      /// </summary>
      [TestClass]
      public class TfLiteTest
      {
          /// <summary>
          /// Checks that the native library reports a version string.
          /// </summary>
          [TestMethod]
          public void TfLiteVersion()
          {
              var ver = c_api_lite.StringPiece(c_api_lite.TfLiteVersion());
              Assert.IsNotNull(ver);
          }

          /// <summary>
          /// Runs the float32 model in <c>Lite/testdata/add.bin</c> end to end: resizes the single
          /// input to shape [2], feeds {1, 3}, invokes the interpreter and expects {3, 9} back.
          /// </summary>
          [TestMethod]
          public unsafe void SmokeTest()
          {
              var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin");
              var options = c_api_lite.TfLiteInterpreterOptionsCreate();
              c_api_lite.TfLiteInterpreterOptionsSetNumThreads(options, 2);

              var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options);

              // The model and options are released right after the interpreter is created;
              // nothing below uses them again.
              // NOTE(review): these are released via DangerousGetHandle(); if the Safe*Handle
              // types also delete in ReleaseHandle, Dispose() would avoid a second release — confirm.
              c_api_lite.TfLiteInterpreterOptionsDelete(options.DangerousGetHandle());
              c_api_lite.TfLiteModelDelete(model.DangerousGetHandle());

              try
              {
                  Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
                  Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetInputTensorCount(interpreter));
                  Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetOutputTensorCount(interpreter));

                  // Resize the input to a 1-D tensor of length 2, then reallocate for the new shape.
                  var input_dims = new int[] { 2 };
                  Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length));
                  Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));

                  var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
                  Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(input_tensor));
                  Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
                  Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));
                  Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor));
                  // IsNotNull always passes for a boxed IntPtr, so compare against IntPtr.Zero.
                  // NOTE(review): assumes TfLiteTensorData is bound as returning IntPtr — confirm.
                  Assert.AreNotEqual(IntPtr.Zero, c_api_lite.TfLiteTensorData(input_tensor));
                  Assert.AreEqual("input", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(input_tensor)));

                  // A float model reports no quantization.
                  var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor);
                  Assert.AreEqual(0f, input_params.scale);
                  Assert.AreEqual(0, input_params.zero_point);

                  var input = new[] { 1f, 3f };
                  fixed (float* addr = input)
                  {
                      Assert.AreEqual(TfLiteStatus.kTfLiteOk,
                          c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(float)));
                  }

                  Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));

                  var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
                  Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(output_tensor));
                  Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor));
                  Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0));
                  Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor));
                  Assert.AreNotEqual(IntPtr.Zero, c_api_lite.TfLiteTensorData(output_tensor));
                  Assert.AreEqual("output", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(output_tensor)));

                  var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor);
                  Assert.AreEqual(0f, output_params.scale);
                  Assert.AreEqual(0, output_params.zero_point);

                  var output = new float[2];
                  fixed (float* addr = output)
                  {
                      Assert.AreEqual(TfLiteStatus.kTfLiteOk,
                          c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(float)));
                  }

                  // Small integers are exactly representable, so exact comparison is safe here.
                  Assert.AreEqual(3f, output[0]);
                  Assert.AreEqual(9f, output[1]);
              }
              finally
              {
                  // Runs even when an assertion above throws, so a failing test does not leak
                  // the native interpreter.
                  c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
              }
          }

          /// <summary>
          /// Runs the uint8 model in <c>Lite/testdata/add_quantized.bin</c>: checks the
          /// input/output quantization parameters (scale 0.003922, zero point 0), feeds {1, 3},
          /// expects quantized output {3, 9} and checks the dequantized values.
          /// </summary>
          [TestMethod]
          public unsafe void QuantizationParamsTest()
          {
              var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add_quantized.bin");
              // A handle wrapping IntPtr.Zero passes null options — presumably interpreter defaults.
              var interpreter = c_api_lite.TfLiteInterpreterCreate(model, new SafeTfLiteInterpreterOptionsHandle(IntPtr.Zero));
              // NOTE(review): released via DangerousGetHandle(); confirm the SafeHandle does not release it again.
              c_api_lite.TfLiteModelDelete(model.DangerousGetHandle());

              try
              {
                  var input_dims = new[] { 2 };
                  Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length));
                  Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));

                  var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
                  Assert.IsNotNull(input_tensor);
                  Assert.AreEqual(TfLiteDataType.kTfLiteUInt8, c_api_lite.TfLiteTensorType(input_tensor));
                  Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
                  Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));

                  // The scale is a float read from the model, so compare it with a tolerance.
                  var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor);
                  Assert.AreEqual(0.003922f, input_params.scale, 1e-6f);
                  Assert.AreEqual(0, input_params.zero_point);

                  var input = new byte[] { 1, 3 };
                  fixed (byte* addr = input)
                  {
                      Assert.AreEqual(TfLiteStatus.kTfLiteOk,
                          c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(byte)));
                  }

                  Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));

                  var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
                  Assert.IsNotNull(output_tensor);

                  var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor);
                  Assert.AreEqual(0.003922f, output_params.scale, 1e-6f);
                  Assert.AreEqual(0, output_params.zero_point);

                  var output = new byte[2];
                  fixed (byte* addr = output)
                  {
                      Assert.AreEqual(TfLiteStatus.kTfLiteOk,
                          c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(byte)));
                  }

                  // The output buffer holds raw quantized bytes; compare them as bytes.
                  Assert.AreEqual((byte)3, output[0]);
                  Assert.AreEqual((byte)9, output[1]);

                  // Dequantize: real = scale * (quantized - zero_point).
                  var dequantizedOutput0 = output_params.scale * (output[0] - output_params.zero_point);
                  var dequantizedOutput1 = output_params.scale * (output[1] - output_params.zero_point);
                  // Expected value first (MSTest convention); the products are floating-point,
                  // so use a tolerance rather than exact equality.
                  Assert.AreEqual(0.011766f, dequantizedOutput0, 1e-6f);
                  Assert.AreEqual(0.035298f, dequantizedOutput1, 1e-6f);
              }
              finally
              {
                  // Runs even when an assertion above throws, so a failing test does not leak
                  // the native interpreter.
                  c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
              }
          }
      }
  }