@@ -9,7 +9,7 @@ | |||||
[](https://996.icu/#/en_US) | [](https://996.icu/#/en_US) | ||||
[](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) | [](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) | ||||
*master branch is based on tensorflow v2.4, v0.3x branch is based on tensorflow v2.3, v0.15-tensorflow1.15 is from tensorflow1.15.* | |||||
*master branch is based on tensorflow v2.x, v0.6x branch is based on tensorflow v2.6, v0.15-tensorflow1.15 is from tensorflow1.15.* | |||||
 |  | ||||
@@ -249,10 +249,6 @@ Follow us on [Twitter](https://twitter.com/ScisharpStack), [Facebook](https://ww | |||||
Join our chat on [Gitter](https://gitter.im/sci-sharp/community). | Join our chat on [Gitter](https://gitter.im/sci-sharp/community). | ||||
Scan QR code to join Tencent TIM group: | |||||
 | |||||
WeChat Sponsor 微信打赏: | WeChat Sponsor 微信打赏: | ||||
 |  | ||||
@@ -0,0 +1,91 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
using Tensorflow.Lite; | |||||
namespace Tensorflow | |||||
{ | |||||
public class c_api_lite | |||||
{ | |||||
public const string TensorFlowLibName = "tensorflowlite_c"; | |||||
public static string StringPiece(IntPtr handle) | |||||
{ | |||||
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | |||||
} | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TfLiteVersion(); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern SafeTfLiteModelHandle TfLiteModelCreateFromFile(string model_path); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TfLiteModelDelete(IntPtr model); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern SafeTfLiteInterpreterOptionsHandle TfLiteInterpreterOptionsCreate(); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TfLiteInterpreterOptionsDelete(IntPtr options); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TfLiteInterpreterOptionsSetNumThreads(SafeTfLiteInterpreterOptionsHandle options, int num_threads); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern SafeTfLiteInterpreterHandle TfLiteInterpreterCreate(SafeTfLiteModelHandle model, SafeTfLiteInterpreterOptionsHandle optional_options); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TfLiteInterpreterDelete(IntPtr interpreter); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteStatus TfLiteInterpreterAllocateTensors(SafeTfLiteInterpreterHandle interpreter); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TfLiteInterpreterGetInputTensorCount(SafeTfLiteInterpreterHandle interpreter); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TfLiteInterpreterGetOutputTensorCount(SafeTfLiteInterpreterHandle interpreter); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteStatus TfLiteInterpreterResizeInputTensor(SafeTfLiteInterpreterHandle interpreter, | |||||
int input_index, int[] input_dims, int input_dims_size); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteTensor TfLiteInterpreterGetInputTensor(SafeTfLiteInterpreterHandle interpreter, int input_index); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteDataType TfLiteTensorType(TfLiteTensor tensor); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TfLiteTensorNumDims(TfLiteTensor tensor); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TfLiteTensorByteSize(TfLiteTensor tensor); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TfLiteTensorData(TfLiteTensor tensor); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TfLiteTensorName(TfLiteTensor tensor); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteQuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteStatus TfLiteInterpreterInvoke(SafeTfLiteInterpreterHandle interpreter); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TfLiteInterpreterGetOutputTensor(SafeTfLiteInterpreterHandle interpreter, int output_index); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TfLiteStatus TfLiteTensorCopyToBuffer(TfLiteTensor output_tensor, IntPtr output_data, int output_data_size); | |||||
} | |||||
} |
@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Lite | |||||
{ | |||||
public class SafeTfLiteInterpreterHandle : SafeTensorflowHandle | |||||
{ | |||||
protected SafeTfLiteInterpreterHandle() | |||||
{ | |||||
} | |||||
public SafeTfLiteInterpreterHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api_lite.TfLiteInterpreterDelete(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Lite | |||||
{ | |||||
public class SafeTfLiteInterpreterOptionsHandle : SafeTensorflowHandle | |||||
{ | |||||
protected SafeTfLiteInterpreterOptionsHandle() | |||||
{ | |||||
} | |||||
public SafeTfLiteInterpreterOptionsHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api_lite.TfLiteInterpreterOptionsDelete(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Lite | |||||
{ | |||||
public class SafeTfLiteModelHandle : SafeTensorflowHandle | |||||
{ | |||||
protected SafeTfLiteModelHandle() | |||||
{ | |||||
} | |||||
public SafeTfLiteModelHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api_lite.TfLiteModelDelete(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,27 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Lite | |||||
{ | |||||
public enum TfLiteDataType | |||||
{ | |||||
kTfLiteNoType = 0, | |||||
kTfLiteFloat32 = 1, | |||||
kTfLiteInt32 = 2, | |||||
kTfLiteUInt8 = 3, | |||||
kTfLiteInt64 = 4, | |||||
kTfLiteString = 5, | |||||
kTfLiteBool = 6, | |||||
kTfLiteInt16 = 7, | |||||
kTfLiteComplex64 = 8, | |||||
kTfLiteInt8 = 9, | |||||
kTfLiteFloat16 = 10, | |||||
kTfLiteFloat64 = 11, | |||||
kTfLiteComplex128 = 12, | |||||
kTfLiteUInt64 = 13, | |||||
kTfLiteResource = 14, | |||||
kTfLiteVariant = 15, | |||||
kTfLiteUInt32 = 16, | |||||
} | |||||
} |
@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Lite | |||||
{ | |||||
public struct TfLiteQuantizationParams | |||||
{ | |||||
public float scale; | |||||
public int zero_point; | |||||
} | |||||
} |
@@ -0,0 +1,31 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Lite | |||||
{ | |||||
public enum TfLiteStatus | |||||
{ | |||||
kTfLiteOk = 0, | |||||
// Generally referring to an error in the runtime (i.e. interpreter) | |||||
kTfLiteError = 1, | |||||
// Generally referring to an error from a TfLiteDelegate itself. | |||||
kTfLiteDelegateError = 2, | |||||
// Generally referring to an error in applying a delegate due to | |||||
// incompatibility between runtime and delegate, e.g., this error is returned | |||||
// when trying to apply a TfLite delegate onto a model graph that's already | |||||
// immutable. | |||||
kTfLiteApplicationError = 3, | |||||
// Generally referring to serialized delegate data not being found. | |||||
// See tflite::delegates::Serialization. | |||||
kTfLiteDelegateDataNotFound = 4, | |||||
// Generally referring to data-writing issues in delegate serialization. | |||||
// See tflite::delegates::Serialization. | |||||
kTfLiteDelegateDataWriteError = 5, | |||||
} | |||||
} |
@@ -0,0 +1,21 @@ | |||||
using System; | |||||
namespace Tensorflow.Lite | |||||
{ | |||||
public struct TfLiteTensor | |||||
{ | |||||
IntPtr _handle; | |||||
public TfLiteTensor(IntPtr handle) | |||||
=> _handle = handle; | |||||
public static implicit operator TfLiteTensor(IntPtr handle) | |||||
=> new TfLiteTensor(handle); | |||||
public static implicit operator IntPtr(TfLiteTensor tensor) | |||||
=> tensor._handle; | |||||
public override string ToString() | |||||
=> $"TfLiteTensor 0x{_handle.ToString("x16")}"; | |||||
} | |||||
} |
@@ -5,7 +5,7 @@ | |||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | <TargetTensorFlow>2.2.0</TargetTensorFlow> | ||||
<Version>0.60.3</Version> | |||||
<Version>0.60.4</Version> | |||||
<LangVersion>9.0</LangVersion> | <LangVersion>9.0</LangVersion> | ||||
<Nullable>enable</Nullable> | <Nullable>enable</Nullable> | ||||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
@@ -20,7 +20,7 @@ | |||||
<Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
Building, training and infering deep learning models. | Building, training and infering deep learning models. | ||||
https://tensorflownet.readthedocs.io</Description> | https://tensorflownet.readthedocs.io</Description> | ||||
<AssemblyVersion>0.60.3.0</AssemblyVersion> | |||||
<AssemblyVersion>0.60.4.0</AssemblyVersion> | |||||
<PackageReleaseNotes>tf.net 0.60.x and above are based on tensorflow native 2.6.0 | <PackageReleaseNotes>tf.net 0.60.x and above are based on tensorflow native 2.6.0 | ||||
* Eager Mode is added finally. | * Eager Mode is added finally. | ||||
@@ -35,7 +35,7 @@ Keras API is a separate package released as TensorFlow.Keras. | |||||
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. | tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. | ||||
tf.net 0.5x.x aligns with TensorFlow v2.5.x native library. | tf.net 0.5x.x aligns with TensorFlow v2.5.x native library. | ||||
tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes> | tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes> | ||||
<FileVersion>0.60.3.0</FileVersion> | |||||
<FileVersion>0.60.4.0</FileVersion> | |||||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
@@ -189,7 +189,10 @@ namespace Tensorflow | |||||
{ | { | ||||
TF_DataType.TF_STRING => "string", | TF_DataType.TF_STRING => "string", | ||||
TF_DataType.TF_UINT8 => "uint8", | TF_DataType.TF_UINT8 => "uint8", | ||||
TF_DataType.TF_INT8 => "int8", | |||||
TF_DataType.TF_UINT32 => "uint32", | |||||
TF_DataType.TF_INT32 => "int32", | TF_DataType.TF_INT32 => "int32", | ||||
TF_DataType.TF_UINT64 => "uint64", | |||||
TF_DataType.TF_INT64 => "int64", | TF_DataType.TF_INT64 => "int64", | ||||
TF_DataType.TF_FLOAT => "float32", | TF_DataType.TF_FLOAT => "float32", | ||||
TF_DataType.TF_DOUBLE => "float64", | TF_DataType.TF_DOUBLE => "float64", | ||||
@@ -204,9 +207,12 @@ namespace Tensorflow | |||||
{ | { | ||||
TF_DataType.TF_BOOL => sizeof(bool), | TF_DataType.TF_BOOL => sizeof(bool), | ||||
TF_DataType.TF_UINT8 => sizeof(byte), | TF_DataType.TF_UINT8 => sizeof(byte), | ||||
TF_DataType.TF_INT8 => sizeof(byte), | |||||
TF_DataType.TF_INT8 => sizeof(sbyte), | |||||
TF_DataType.TF_UINT16 => sizeof(ushort), | |||||
TF_DataType.TF_INT16 => sizeof(short), | TF_DataType.TF_INT16 => sizeof(short), | ||||
TF_DataType.TF_UINT32 => sizeof(uint), | |||||
TF_DataType.TF_INT32 => sizeof(int), | TF_DataType.TF_INT32 => sizeof(int), | ||||
TF_DataType.TF_UINT64 => sizeof(ulong), | |||||
TF_DataType.TF_INT64 => sizeof(long), | TF_DataType.TF_INT64 => sizeof(long), | ||||
TF_DataType.TF_FLOAT => sizeof(float), | TF_DataType.TF_FLOAT => sizeof(float), | ||||
TF_DataType.TF_DOUBLE => sizeof(double), | TF_DataType.TF_DOUBLE => sizeof(double), | ||||
@@ -11,7 +11,9 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow.Benchmark.Leak | namespace Tensorflow.Benchmark.Leak | ||||
{ | { | ||||
/// <summary> | |||||
/// https://github.com/SciSharp/TensorFlow.NET/issues/418 | |||||
/// </summary> | |||||
public class SavedModelCleanup | public class SavedModelCleanup | ||||
{ | { | ||||
[Benchmark] | [Benchmark] | ||||
@@ -36,7 +36,7 @@ | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="BenchmarkDotNet" Version="0.13.0" /> | |||||
<PackageReference Include="BenchmarkDotNet" Version="0.13.1" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -24,14 +24,14 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.10.0" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.5" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.5" /> | |||||
<PackageReference Include="coverlet.collector" Version="3.0.3"> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.7" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.7" /> | |||||
<PackageReference Include="coverlet.collector" Version="3.1.0"> | |||||
<PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
</PackageReference> | </PackageReference> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -14,14 +14,14 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0-release-20210626-04" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.5" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.5" /> | |||||
<PackageReference Include="coverlet.collector" Version="3.0.3"> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.7" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.7" /> | |||||
<PackageReference Include="coverlet.collector" Version="3.1.0"> | |||||
<PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
</PackageReference> | </PackageReference> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -0,0 +1,138 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using Tensorflow.Lite; | |||||
namespace Tensorflow.Native.UnitTest | |||||
{ | |||||
[TestClass] | |||||
public class TfLiteTest | |||||
{ | |||||
[TestMethod] | |||||
public void TfLiteVersion() | |||||
{ | |||||
var ver = c_api_lite.StringPiece(c_api_lite.TfLiteVersion()); | |||||
Assert.IsNotNull(ver); | |||||
} | |||||
[TestMethod] | |||||
public unsafe void SmokeTest() | |||||
{ | |||||
var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin"); | |||||
var options = c_api_lite.TfLiteInterpreterOptionsCreate(); | |||||
c_api_lite.TfLiteInterpreterOptionsSetNumThreads(options, 2); | |||||
var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options); | |||||
c_api_lite.TfLiteInterpreterOptionsDelete(options.DangerousGetHandle()); | |||||
c_api_lite.TfLiteModelDelete(model.DangerousGetHandle()); | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); | |||||
Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetInputTensorCount(interpreter)); | |||||
Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetOutputTensorCount(interpreter)); | |||||
var input_dims = new int[] { 2 }; | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length)); | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); | |||||
var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0); | |||||
Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(input_tensor)); | |||||
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor)); | |||||
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0)); | |||||
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor)); | |||||
Assert.IsNotNull(c_api_lite.TfLiteTensorData(input_tensor)); | |||||
Assert.AreEqual("input", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(input_tensor))); | |||||
var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor); | |||||
Assert.AreEqual(0f, input_params.scale); | |||||
Assert.AreEqual(0, input_params.zero_point); | |||||
var input = new[] { 1f, 3f }; | |||||
fixed (float* addr = &input[0]) | |||||
{ | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, | |||||
c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(float))); | |||||
} | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter)); | |||||
var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0); | |||||
Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(output_tensor)); | |||||
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor)); | |||||
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0)); | |||||
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor)); | |||||
Assert.IsNotNull(c_api_lite.TfLiteTensorData(output_tensor)); | |||||
Assert.AreEqual("output", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(output_tensor))); | |||||
var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor); | |||||
Assert.AreEqual(0f, output_params.scale); | |||||
Assert.AreEqual(0, output_params.zero_point); | |||||
var output = new float[2]; | |||||
fixed (float* addr = &output[0]) | |||||
{ | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, | |||||
c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(float))); | |||||
} | |||||
Assert.AreEqual(3f, output[0]); | |||||
Assert.AreEqual(9f, output[1]); | |||||
c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle()); | |||||
} | |||||
[TestMethod] | |||||
public unsafe void QuantizationParamsTest() | |||||
{ | |||||
var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add_quantized.bin"); | |||||
var interpreter = c_api_lite.TfLiteInterpreterCreate(model, new SafeTfLiteInterpreterOptionsHandle(IntPtr.Zero)); | |||||
c_api_lite.TfLiteModelDelete(model.DangerousGetHandle()); | |||||
var input_dims = new[] { 2 }; | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, 1)); | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); | |||||
var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0); | |||||
Assert.IsNotNull(input_tensor); | |||||
Assert.AreEqual(TfLiteDataType.kTfLiteUInt8, c_api_lite.TfLiteTensorType(input_tensor)); | |||||
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor)); | |||||
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0)); | |||||
var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor); | |||||
Assert.AreEqual((0.003922f, 0), (input_params.scale, input_params.zero_point)); | |||||
var input = new byte[] { 1, 3 }; | |||||
fixed (byte* addr = &input[0]) | |||||
{ | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, | |||||
c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(byte))); | |||||
} | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter)); | |||||
var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0); | |||||
Assert.IsNotNull(output_tensor); | |||||
var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor); | |||||
Assert.AreEqual((0.003922f, 0), (output_params.scale, output_params.zero_point)); | |||||
var output = new byte[2]; | |||||
fixed (byte* addr = &output[0]) | |||||
{ | |||||
Assert.AreEqual(TfLiteStatus.kTfLiteOk, | |||||
c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(byte))); | |||||
} | |||||
Assert.AreEqual(3f, output[0]); | |||||
Assert.AreEqual(9f, output[1]); | |||||
var dequantizedOutput0 = output_params.scale * (output[0] - output_params.zero_point); | |||||
var dequantizedOutput1 = output_params.scale * (output[1] - output_params.zero_point); | |||||
Assert.AreEqual(dequantizedOutput0, 0.011766f); | |||||
Assert.AreEqual(dequantizedOutput1, 0.035298f); | |||||
c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle()); | |||||
} | |||||
} | |||||
} |
@@ -24,16 +24,31 @@ | |||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | |||||
<None Remove="Lite\testdata\add.bin" /> | |||||
<None Remove="Lite\testdata\add_quantized.bin" /> | |||||
</ItemGroup> | |||||
<ItemGroup> | |||||
<Content Include="Lite\testdata\add.bin"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</Content> | |||||
<Content Include="Lite\testdata\add_quantized.bin"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</Content> | |||||
</ItemGroup> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0-release-20210626-04" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.5" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.5" /> | |||||
<PackageReference Include="coverlet.collector" Version="3.0.3"> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.7" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.7" /> | |||||
<PackageReference Include="coverlet.collector" Version="3.1.0"> | |||||
<PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
</PackageReference> | </PackageReference> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist-Lite" Version="2.6.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -48,10 +48,10 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.139" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.139" /> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0-release-20210626-04" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.5" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.5" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.7" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.7" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||