You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TfLiteTest.cs 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Tensorflow.Lite;
  9. namespace Tensorflow.Native.UnitTest
  10. {
  11. [TestClass]
  12. public class TfLiteTest
  13. {
  14. [TestMethod]
  15. [Ignore]
  16. public void TfLiteVersion()
  17. {
  18. var ver = c_api_lite.StringPiece(c_api_lite.TfLiteVersion());
  19. Assert.IsNotNull(ver);
  20. }
  21. [TestMethod]
  22. [Ignore]
  23. public unsafe void SmokeTest()
  24. {
  25. var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin");
  26. var options = c_api_lite.TfLiteInterpreterOptionsCreate();
  27. c_api_lite.TfLiteInterpreterOptionsSetNumThreads(options, 2);
  28. var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options);
  29. c_api_lite.TfLiteInterpreterOptionsDelete(options.DangerousGetHandle());
  30. c_api_lite.TfLiteModelDelete(model.DangerousGetHandle());
  31. Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
  32. Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetInputTensorCount(interpreter));
  33. Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetOutputTensorCount(interpreter));
  34. var input_dims = new int[] { 2 };
  35. Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length));
  36. Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
  37. var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
  38. Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(input_tensor));
  39. Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
  40. Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));
  41. Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor));
  42. Assert.IsNotNull(c_api_lite.TfLiteTensorData(input_tensor));
  43. Assert.AreEqual("input", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(input_tensor)));
  44. var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor);
  45. Assert.AreEqual(0f, input_params.scale);
  46. Assert.AreEqual(0, input_params.zero_point);
  47. var input = new[] { 1f, 3f };
  48. fixed (float* addr = &input[0])
  49. {
  50. Assert.AreEqual(TfLiteStatus.kTfLiteOk,
  51. c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(float)));
  52. }
  53. Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));
  54. var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
  55. Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(output_tensor));
  56. Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor));
  57. Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0));
  58. Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor));
  59. Assert.IsNotNull(c_api_lite.TfLiteTensorData(output_tensor));
  60. Assert.AreEqual("output", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(output_tensor)));
  61. var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor);
  62. Assert.AreEqual(0f, output_params.scale);
  63. Assert.AreEqual(0, output_params.zero_point);
  64. var output = new float[2];
  65. fixed (float* addr = &output[0])
  66. {
  67. Assert.AreEqual(TfLiteStatus.kTfLiteOk,
  68. c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(float)));
  69. }
  70. Assert.AreEqual(3f, output[0]);
  71. Assert.AreEqual(9f, output[1]);
  72. c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
  73. }
  74. [TestMethod]
  75. [Ignore]
  76. public unsafe void QuantizationParamsTest()
  77. {
  78. var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add_quantized.bin");
  79. var interpreter = c_api_lite.TfLiteInterpreterCreate(model, new SafeTfLiteInterpreterOptionsHandle(IntPtr.Zero));
  80. c_api_lite.TfLiteModelDelete(model.DangerousGetHandle());
  81. var input_dims = new[] { 2 };
  82. Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, 1));
  83. Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
  84. var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
  85. Assert.IsNotNull(input_tensor);
  86. Assert.AreEqual(TfLiteDataType.kTfLiteUInt8, c_api_lite.TfLiteTensorType(input_tensor));
  87. Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
  88. Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));
  89. var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor);
  90. Assert.AreEqual((0.003922f, 0), (input_params.scale, input_params.zero_point));
  91. var input = new byte[] { 1, 3 };
  92. fixed (byte* addr = &input[0])
  93. {
  94. Assert.AreEqual(TfLiteStatus.kTfLiteOk,
  95. c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(byte)));
  96. }
  97. Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));
  98. var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
  99. Assert.IsNotNull(output_tensor);
  100. var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor);
  101. Assert.AreEqual((0.003922f, 0), (output_params.scale, output_params.zero_point));
  102. var output = new byte[2];
  103. fixed (byte* addr = &output[0])
  104. {
  105. Assert.AreEqual(TfLiteStatus.kTfLiteOk,
  106. c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(byte)));
  107. }
  108. Assert.AreEqual(3f, output[0]);
  109. Assert.AreEqual(9f, output[1]);
  110. var dequantizedOutput0 = output_params.scale * (output[0] - output_params.zero_point);
  111. var dequantizedOutput1 = output_params.scale * (output[1] - output_params.zero_point);
  112. Assert.AreEqual(dequantizedOutput0, 0.011766f);
  113. Assert.AreEqual(dequantizedOutput1, 0.035298f);
  114. c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
  115. }
  116. }
  117. }