|
|
@@ -40,7 +40,7 @@ namespace Tensorflow.Native.UnitTest |
|
|
|
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); |
|
|
|
|
|
|
|
var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0); |
|
|
|
Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(input_tensor)); |
|
|
|
Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(input_tensor)); |
|
|
|
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor)); |
|
|
|
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0)); |
|
|
|
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor)); |
|
|
@@ -61,7 +61,7 @@ namespace Tensorflow.Native.UnitTest |
|
|
|
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter)); |
|
|
|
|
|
|
|
var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0); |
|
|
|
Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(output_tensor)); |
|
|
|
Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(output_tensor)); |
|
|
|
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor)); |
|
|
|
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0)); |
|
|
|
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor)); |
|
|
@@ -83,5 +83,56 @@ namespace Tensorflow.Native.UnitTest |
|
|
|
|
|
|
|
c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle()); |
|
|
|
} |
|
|
|
|
|
|
|
[TestMethod] |
|
|
|
public unsafe void QuantizationParamsTest() |
|
|
|
{ |
|
|
|
var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add_quantized.bin"); |
|
|
|
var interpreter = c_api_lite.TfLiteInterpreterCreate(model, new SafeTfLiteInterpreterOptionsHandle(IntPtr.Zero)); |
|
|
|
c_api_lite.TfLiteModelDelete(model.DangerousGetHandle()); |
|
|
|
var input_dims = new[] { 2 }; |
|
|
|
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, 1)); |
|
|
|
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); |
|
|
|
|
|
|
|
var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0); |
|
|
|
Assert.IsNotNull(input_tensor); |
|
|
|
|
|
|
|
Assert.AreEqual(TfLiteDataType.kTfLiteUInt8, c_api_lite.TfLiteTensorType(input_tensor)); |
|
|
|
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor)); |
|
|
|
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0)); |
|
|
|
|
|
|
|
var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor); |
|
|
|
Assert.AreEqual((0.003922f, 0), (input_params.scale, input_params.zero_point)); |
|
|
|
|
|
|
|
var input = new byte[] { 1, 3 }; |
|
|
|
fixed (byte* addr = &input[0]) |
|
|
|
{ |
|
|
|
Assert.AreEqual(TfLiteStatus.kTfLiteOk, |
|
|
|
c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(byte))); |
|
|
|
} |
|
|
|
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter)); |
|
|
|
|
|
|
|
var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0); |
|
|
|
Assert.IsNotNull(output_tensor); |
|
|
|
|
|
|
|
var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor); |
|
|
|
Assert.AreEqual((0.003922f, 0), (output_params.scale, output_params.zero_point)); |
|
|
|
|
|
|
|
var output = new byte[2]; |
|
|
|
fixed (byte* addr = &output[0]) |
|
|
|
{ |
|
|
|
Assert.AreEqual(TfLiteStatus.kTfLiteOk, |
|
|
|
c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(byte))); |
|
|
|
} |
|
|
|
Assert.AreEqual(3f, output[0]); |
|
|
|
Assert.AreEqual(9f, output[1]); |
|
|
|
|
|
|
|
var dequantizedOutput0 = output_params.scale * (output[0] - output_params.zero_point); |
|
|
|
var dequantizedOutput1 = output_params.scale * (output[1] - output_params.zero_point); |
|
|
|
Assert.AreEqual(dequantizedOutput0, 0.011766f); |
|
|
|
Assert.AreEqual(dequantizedOutput1, 0.035298f); |
|
|
|
|
|
|
|
c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |