
feat: update quantize native params.

tags/v0.4.1-preview
Yaohui Liu committed 2 years ago · commit 9850417a12
6 changed files with 87 additions and 15 deletions
  1. LLama/LLamaQuantizer.cs (+13, -4)
  2. LLama/Native/LLamaContextParams.cs (+31, -10)
  3. LLama/Native/LLamaFtype.cs (+9, -0)
  4. LLama/Native/LLamaModelQuantizeParams.cs (+29, -0)
  5. LLama/Native/NativeApi.Quantize.cs (+1, -1)
  6. LLama/Native/NativeApi.cs (+4, -0)

LLama/LLamaQuantizer.cs (+13, -4)

@@ -20,14 +20,22 @@ namespace LLama
         /// <param name="nthread">Thread to be used during the quantization. By default it's the physical core number.</param>
         /// <returns>Whether the quantization is successful.</returns>
         /// <exception cref="ArgumentException"></exception>
-        public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1)
+        public static unsafe bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true,
+            bool quantizeOutputTensor = false)
         {
             if (!ValidateFtype(ftype))
             {
                 throw new ArgumentException($"The type {Enum.GetName(typeof(LLamaFtype), ftype)} is not a valid type " +
                     $"to perform quantization.");
             }
-            return NativeApi.llama_model_quantize(srcFileName, dstFilename, ftype, nthread) == 0;
+
+            var quantizeParams = NativeApi.llama_model_quantize_default_params();
+            quantizeParams.ftype = ftype;
+            quantizeParams.nthread = nthread;
+            quantizeParams.allow_requantize = allowRequantize;
+            quantizeParams.quantize_output_tensor = quantizeOutputTensor;
+            LLamaModelQuantizeParams* p = &quantizeParams;
+            return NativeApi.llama_model_quantize(srcFileName, dstFilename, p) == 0;
         }

         /// <summary>
@@ -39,9 +47,10 @@ namespace LLama
         /// <param name="nthread">Thread to be used during the quantization. By default it's the physical core number.</param>
         /// <returns>Whether the quantization is successful.</returns>
         /// <exception cref="ArgumentException"></exception>
-        public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1)
+        public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1, bool allowRequantize = true,
+            bool quantizeOutputTensor = false)
         {
-            return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread);
+            return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread, allowRequantize, quantizeOutputTensor);
         }

         private static bool ValidateFtype(string ftype)
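
With the new overloads, requantization behaviour is controllable from the managed API. A minimal usage sketch (the model file paths are placeholders, not part of this commit):

    // Quantize an f16 model to 4 bits; allowRequantize permits re-quantizing an
    // already-quantized model, quantizeOutputTensor also quantizes output.weight.
    bool ok = LLamaQuantizer.Quantize(
        "ggml-model-f16.bin", "ggml-model-q4_0.bin",
        LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0,
        nthread: -1, allowRequantize: true, quantizeOutputTensor: false);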


LLama/Native/LLamaContextParams.cs (+31, -10)

@@ -9,19 +9,44 @@ namespace LLama.Native
     [StructLayout(LayoutKind.Sequential)]
     public struct LLamaContextParams
     {
+        /// <summary>
+        /// RNG seed, -1 for random
+        /// </summary>
+        public int seed;
         /// <summary>
         /// text context
         /// </summary>
         public int n_ctx;
         /// <summary>
+        /// prompt processing batch size
+        /// </summary>
+        public int n_batch;
+        /// <summary>
         /// number of layers to store in VRAM
         /// </summary>
         public int n_gpu_layers;
         /// <summary>
-        /// RNG seed, -1 for random
+        /// the GPU that is used for scratch and small tensors
         /// </summary>
-        public int seed;
+        public int main_gpu;
+        /// <summary>
+        /// how to split layers across multiple GPUs
+        /// </summary>
+        public TensorSplits tensor_split;
+        /// <summary>
+        /// called with a progress value between 0 and 1, pass NULL to disable
+        /// </summary>
+        public IntPtr progress_callback;
+        /// <summary>
+        /// context pointer passed to the progress callback
+        /// </summary>
+        public IntPtr progress_callback_user_data;
+
+        /// <summary>
+        /// if true, reduce VRAM usage at the cost of performance
+        /// </summary>
+        [MarshalAs(UnmanagedType.I1)]
+        public bool low_vram;
         /// <summary>
         /// use fp16 for KV cache
         /// </summary>
@@ -52,14 +77,10 @@ namespace LLama.Native
         /// </summary>
         [MarshalAs(UnmanagedType.I1)]
         public bool embedding;
-
-        /// <summary>
-        /// called with a progress value between 0 and 1, pass NULL to disable
-        /// </summary>
-        public IntPtr progress_callback;
-        /// <summary>
-        /// context pointer passed to the progress callback
-        /// </summary>
-        public IntPtr progress_callback_user_data;
+    }
+
+    public struct TensorSplits
+    {
+        public float Item1;
     }
 }
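
The reordered struct must stay field-for-field compatible with llama_context_params on the native side, so the safest way to build one is to start from the native defaults. A hedged sketch that only touches fields introduced or moved by this commit (the values are arbitrary examples):

    var ctxParams = NativeApi.llama_context_default_params();
    ctxParams.seed = 42;        // RNG seed, -1 for random
    ctxParams.n_batch = 512;    // prompt processing batch size
    ctxParams.main_gpu = 0;     // GPU used for scratch and small tensors
    ctxParams.low_vram = false; // set true to trade speed for lower VRAM usage

Since NativeApi.cs pins LLAMA_MAX_DEVICES to 1 in this build, TensorSplits only needs the single float slot it declares.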

LLama/Native/LLamaFtype.cs (+9, -0)

@@ -16,5 +16,14 @@ namespace LLama.Native
         LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
     }
 }
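
The new k-quant formats can be requested directly through the enum overload; for example (output path hypothetical):

    LLamaQuantizer.Quantize("ggml-model-f16.bin", "ggml-model-q5_k_m.bin",
                            LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_M);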

LLama/Native/LLamaModelQuantizeParams.cs (+29, -0)

@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace LLama.Native
+{
+    public struct LLamaModelQuantizeParams
+    {
+        /// <summary>
+        /// number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+        /// </summary>
+        public int nthread;
+        /// <summary>
+        /// quantize to this llama_ftype
+        /// </summary>
+        public LLamaFtype ftype;
+        /// <summary>
+        /// allow quantizing non-f32/f16 tensors
+        /// </summary>
+        [MarshalAs(UnmanagedType.I1)]
+        public bool allow_requantize;
+        /// <summary>
+        /// quantize output.weight
+        /// </summary>
+        [MarshalAs(UnmanagedType.I1)]
+        public bool quantize_output_tensor;
+    }
+}
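
The struct relies on the default sequential layout for C# structs, and the two flags are marshalled as single bytes to match the native bools. A sketch of filling it by hand instead of starting from the native defaults (every field should be set explicitly, since an untouched field silently passes 0 to the native side):

    var qp = new LLamaModelQuantizeParams
    {
        nthread = 0,                    // <=0 -> std::thread::hardware_concurrency()
        ftype = LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1,
        allow_requantize = true,        // permit re-quantizing an already-quantized model
        quantize_output_tensor = false  // leave output.weight unquantized
    };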

LLama/Native/NativeApi.Quantize.cs (+1, -1)

@@ -17,6 +17,6 @@ namespace LLama.Native
         /// <remarks>not great API - very likely to change</remarks>
         /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaFtype ftype, int nthread);
+        public unsafe static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param);
     }
 }
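
Because the parameters are now passed by pointer, callers of the raw binding need an unsafe context to take the struct's address. A hedged sketch of the direct native call (file names are placeholders):

    unsafe
    {
        var qp = NativeApi.llama_model_quantize_default_params();
        qp.ftype = LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_M;
        int rc = NativeApi.llama_model_quantize("ggml-model-f16.bin",
                                                "ggml-model-q4_k_m.bin", &qp);
        // 0 indicates success, per the <returns> documentation above
    }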

LLama/Native/NativeApi.cs (+4, -0)

@@ -10,6 +10,7 @@ namespace LLama.Native
     using llama_token = Int32;
     public unsafe partial class NativeApi
     {
+        public static readonly int LLAMA_MAX_DEVICES = 1;
         static NativeApi()
         {
             try
@@ -34,6 +35,9 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern LLamaContextParams llama_context_default_params();

+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern LLamaModelQuantizeParams llama_model_quantize_default_params();
+
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern bool llama_mmap_supported();



