@@ -1,13 +1,7 @@
-using System.Reflection.Metadata;
 using System.Security.Cryptography;
-using System.Text;
-using LLama.Abstractions;
 using LLama.Common;
-using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.AI.ChatCompletion;
-using Microsoft.SemanticKernel.AI.TextCompletion;
 using LLamaSharp.SemanticKernel.ChatCompletion;
-using LLamaSharp.SemanticKernel.TextCompletion;

 namespace LLama.Examples.NewVersion
 {
@@ -22,7 +16,7 @@ namespace LLama.Examples.NewVersion
             // Load weights into memory
             var parameters = new ModelParams(modelPath)
             {
-                Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
+                Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
             };
             using var model = LLamaWeights.LoadFromFile(parameters);
             using var context = model.CreateContext(parameters);
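Note: `Seed` is now a `uint`, while `RandomNumberGenerator.GetInt32` returns a signed `int`, hence the cast above. A minimal standalone sketch of the pattern (illustrative, not part of the diff):

```csharp
using System;
using System.Security.Cryptography;

// GetInt32(int.MaxValue) returns a value in [0, int.MaxValue), so the cast
// never overflows; `unchecked` simply documents the reinterpretation as uint.
uint seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue));
Console.WriteLine($"seed = {seed}");
```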
@@ -22,7 +22,7 @@ namespace LLama.Examples.NewVersion
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
-            var seed = 1337;
+            var seed = 1337u;

             // Load weights into memory
             var parameters = new ModelParams(modelPath)
             {
@@ -21,7 +21,7 @@ namespace LLama.Examples.NewVersion
             // Load weights into memory
             var parameters = new ModelParams(modelPath)
             {
-                Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
+                Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
             };
             using var model = LLamaWeights.LoadFromFile(parameters);
             var ex = new StatelessExecutor(model, parameters);
@@ -15,7 +15,7 @@ namespace LLama.Examples.NewVersion
             // Load weights into memory
             var @params = new ModelParams(modelPath)
             {
-                Seed = RandomNumberGenerator.GetInt32(int.MaxValue)
+                Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
             };
             using var weights = LLamaWeights.LoadFromFile(@params);
@@ -1,4 +1,5 @@
 using LLama.Examples.NewVersion;
+using LLama.Native;

 Console.WriteLine("======================================================================================================");
@@ -7,7 +8,7 @@ Console.WriteLine(" __ __ ____ _
 Console.WriteLine("======================================================================================================");

+NativeApi.llama_empty_call();
 Console.WriteLine();

 await NewVersionTestRunner.Run();
@@ -27,7 +27,7 @@ namespace LLama.Unittest
         public void BasicModelProperties()
         {
             Assert.Equal(32000, _model.VocabCount);
-            Assert.Equal(2048, _model.ContextSize);
+            Assert.Equal(4096, _model.ContextSize);
             Assert.Equal(4096, _model.EmbeddingSize);
             Assert.Equal(Encoding.UTF8, _model.Encoding);
         }
@@ -2,7 +2,7 @@
 namespace LLama.Unittest
 {
-    public class LLamaContextTests
+    public sealed class LLamaContextTests
         : IDisposable
     {
         private readonly LLamaWeights _weights;
@@ -10,10 +10,7 @@ namespace LLama.Unittest
         public LLamaContextTests()
         {
-            var @params = new ModelParams(Constants.ModelPath)
-            {
-                ContextSize = 768,
-            };
+            var @params = new ModelParams(Constants.ModelPath);
             _weights = LLamaWeights.LoadFromFile(@params);
             _context = _weights.CreateContext(@params);
         }
@@ -27,7 +24,7 @@ namespace LLama.Unittest
         [Fact]
         public void CheckProperties()
         {
-            Assert.Equal(768, _context.ContextSize);
+            Assert.Equal(4096, _context.ContextSize);
             Assert.Equal(4096, _context.EmbeddingSize);
             Assert.Equal(32000, _context.VocabCount);
             Assert.Equal(0, _context.KVCacheTokenCount);
@@ -14,7 +14,7 @@ namespace LLama.Web.Common
         /// <summary>
         /// Model context size (n_ctx)
         /// </summary>
-        public int ContextSize { get; set; } = 512;
+        public uint ContextSize { get; set; } = 512;

         /// <summary>
         /// the GPU that is used for scratch and small tensors
         /// </summary>
@@ -30,7 +30,7 @@ namespace LLama.Web.Common
         /// <summary>
         /// Seed for the random number generator (seed)
         /// </summary>
-        public int Seed { get; set; } = 1686349486;
+        public uint Seed { get; set; } = 1686349486;

         /// <summary>
         /// Use f16 instead of f32 for memory kv (memory_f16)
         /// </summary>
@@ -59,10 +59,13 @@ namespace LLama.Web.Common
         /// lora adapter path (lora_adapter)
         /// </summary>
         public string LoraAdapter { get; set; } = string.Empty;
-        /// <summary>
-        /// base model path for the lora adapter (lora_base)
-        /// </summary>
-        public string LoraBase { get; set; } = string.Empty;
+
+        public float LoraAdapterScale { get; set; } = 1;
+
+        /// <summary>
+        /// base model path for the lora adapter (lora_base)
+        /// </summary>
+        public string LoraBase { get; set; } = string.Empty;

         /// <summary>
         /// Number of threads (-1 = autodetect) (n_threads)
         /// </summary>
@@ -70,7 +73,7 @@ namespace LLama.Web.Common
         /// <summary>
         /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
         /// </summary>
-        public int BatchSize { get; set; } = 512;
+        public uint BatchSize { get; set; } = 512;

         /// <summary>
         /// Whether to convert eos to newline during the inference.
@@ -107,5 +110,10 @@ namespace LLama.Web.Common
         /// The encoding to use for models
         /// </summary>
         public Encoding Encoding { get; set; } = Encoding.UTF8;
+
+        /// <summary>
+        /// Load vocab only (no weights)
+        /// </summary>
+        public bool VocabOnly { get; set; }
     }
 }
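Note: with these properties widened to `uint`, positive literals bind as before; only negative or `int`-typed expressions now need an explicit cast. A hedged configuration sketch (values are illustrative, and it assumes `ModelOptions` is populated via an object initializer, as settings classes usually are):

```csharp
using LLama.Web.Common;

var options = new ModelOptions
{
    ContextSize = 4096,  // uint
    BatchSize = 512,     // uint
    Seed = 1337,         // uint
    VocabOnly = false,   // new: load only the vocabulary, no weights
};
```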
@@ -36,7 +36,7 @@ namespace LLama.Abstractions
         /// </summary>
         public int TopK { get; set; }

-        /// <summary>
+        /// <summary>llama_eval
         /// 1.0 = disabled
         /// </summary>
         public float TopP { get; set; }
@@ -10,7 +10,7 @@ namespace LLama.Abstractions
         /// <summary>
         /// Model context size (n_ctx)
         /// </summary>
-        int ContextSize { get; set; }
+        uint ContextSize { get; set; }

         /// <summary>
         /// the GPU that is used for scratch and small tensors
@@ -30,7 +30,7 @@ namespace LLama.Abstractions
         /// <summary>
         /// Seed for the random number generator (seed)
         /// </summary>
-        int Seed { get; set; }
+        uint Seed { get; set; }

         /// <summary>
         /// Use f16 instead of f32 for memory kv (memory_f16)
@@ -62,6 +62,8 @@ namespace LLama.Abstractions
         /// </summary>
         string LoraAdapter { get; set; }

+        float LoraAdapterScale { get; set; }
+
         /// <summary>
         /// base model path for the lora adapter (lora_base)
         /// </summary>
@@ -75,7 +77,7 @@ namespace LLama.Abstractions
         /// <summary>
         /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
         /// </summary>
-        int BatchSize { get; set; }
+        uint BatchSize { get; set; }

         /// <summary>
         /// Whether to use embedding mode. (embedding) Note that if this is set to true,
@@ -107,5 +109,10 @@ namespace LLama.Abstractions
         /// The encoding to use for models
         /// </summary>
         Encoding Encoding { get; set; }
+
+        /// <summary>
+        /// Load vocab only (no weights)
+        /// </summary>
+        bool VocabOnly { get; set; }
     }
 }
@@ -15,7 +15,7 @@ namespace LLama.Common
         /// <summary>
         /// Model context size (n_ctx)
         /// </summary>
-        public int ContextSize { get; set; } = 512;
+        public uint ContextSize { get; set; } = 512;

         /// <summary>
         /// the GPU that is used for scratch and small tensors
         /// </summary>
@@ -31,7 +31,7 @@ namespace LLama.Common
         /// <summary>
         /// Seed for the random number generator (seed)
         /// </summary>
-        public int Seed { get; set; } = 1686349486;
+        public uint Seed { get; set; } = 1686349486;

         /// <summary>
         /// Use f16 instead of f32 for memory kv (memory_f16)
         /// </summary>
@@ -56,6 +56,9 @@ namespace LLama.Common
         /// lora adapter path (lora_adapter)
         /// </summary>
         public string LoraAdapter { get; set; } = string.Empty;
+
+        public float LoraAdapterScale { get; set; } = 1;
+
         /// <summary>
         /// base model path for the lora adapter (lora_base)
         /// </summary>
@@ -67,7 +70,7 @@ namespace LLama.Common
         /// <summary>
         /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
         /// </summary>
-        public int BatchSize { get; set; } = 512;
+        public uint BatchSize { get; set; } = 512;

         /// <summary>
         /// Whether to use embedding mode. (embedding) Note that if this is set to true,
@@ -95,6 +98,11 @@ namespace LLama.Common
         /// </summary>
         public bool MulMatQ { get; set; }

+        /// <summary>
+        /// Load vocab only (no weights)
+        /// </summary>
+        public bool VocabOnly { get; set; }
+
         /// <summary>
         /// The encoding to use to convert text for the model
         /// </summary>
@@ -138,10 +146,10 @@ namespace LLama.Common
         /// <param name="mulMatQ">Use experimental mul_mat_q kernels</param>
         /// <param name="encoding">The encoding to use to convert text for the model</param>
         [Obsolete("Use object initializer to set all optional parameters")]
-        public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
-            int seed = 1337, bool useFp16Memory = true,
+        public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount = 20,
+            uint seed = 1337, bool useFp16Memory = true,
             bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
-            string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
+            string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512,
             bool embeddingMode = false,
             float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
             string encoding = "UTF-8")
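Note: since the constructor is `[Obsolete]`, optional parameters are meant to be set through an object initializer, as the examples earlier in this diff already do. A minimal sketch (the model path is a placeholder):

```csharp
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf")
{
    ContextSize = 4096u,
    Seed = 1337u,
    BatchSize = 512u,
    GpuLayerCount = 20,
};
```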
@@ -19,34 +19,45 @@ namespace LLama.Extensions
         /// <returns></returns>
         /// <exception cref="FileNotFoundException"></exception>
         /// <exception cref="ArgumentException"></exception>
-        public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
+        public static void ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
         {
-            if (!File.Exists(@params.ModelPath))
-                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
-
-            if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
-                throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
-
             result = NativeApi.llama_context_default_params();
             result.n_ctx = @params.ContextSize;
             result.n_batch = @params.BatchSize;
-            result.main_gpu = @params.MainGpu;
-            result.n_gpu_layers = @params.GpuLayerCount;
             result.seed = @params.Seed;
             result.f16_kv = @params.UseFp16Memory;
-            result.use_mmap = @params.UseMemorymap;
-            result.use_mlock = @params.UseMemoryLock;
             result.logits_all = @params.Perplexity;
             result.embedding = @params.EmbeddingMode;
-            result.low_vram = @params.LowVram;
             result.rope_freq_base = @params.RopeFrequencyBase;
             result.rope_freq_scale = @params.RopeFrequencyScale;
             result.mul_mat_q = @params.MulMatQ;
+        }
+
+        /// <summary>
+        /// Convert the given `IModelParams` into a `LLamaModelParams`
+        /// </summary>
+        /// <param name="params"></param>
+        /// <param name="result"></param>
+        /// <returns></returns>
+        /// <exception cref="FileNotFoundException"></exception>
+        /// <exception cref="ArgumentException"></exception>
+        public static MemoryHandle ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
+        {
+            if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
+                throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
+
+            result = NativeApi.llama_model_default_params();
+            result.main_gpu = @params.MainGpu;
+            result.n_gpu_layers = @params.GpuLayerCount;
+            result.use_mlock = @params.UseMemoryLock;
+            result.use_mmap = @params.UseMemorymap;
+            result.vocab_only = @params.VocabOnly;

             var pin = @params.TensorSplits.AsMemory().Pin();
             unsafe
             {
-                result.tensor_split = (nint)pin.Pointer;
+                result.tensor_split = (float*)pin.Pointer;
             }

             return pin;
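Note: `ToLlamaModelParams` returns the `MemoryHandle` pinning `TensorSplits` because the native struct now stores a raw `float*` into that managed array; the caller must keep the pin alive until the native side has consumed the pointer. A hedged sketch of the calling pattern (mirroring what `LLamaWeights.LoadFromFile` does later in this diff):

```csharp
// Dispose the pin only after the native call that reads tensor_split.
using (var pin = modelParams.ToLlamaModelParams(out var lparams))
{
    var handle = SafeLlamaModelHandle.LoadFromFile(modelParams.ModelPath, lparams);
    // ... use handle ...
}
// After disposal the GC is free to move the TensorSplits array again.
```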
@@ -105,7 +105,7 @@ namespace LLama
             _logger = logger;
             _encoding = @params.Encoding;

-            using var pin = @params.ToLlamaContextParams(out var lparams);
+            @params.ToLlamaContextParams(out var lparams);
             _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
         }

@@ -115,9 +115,9 @@ namespace LLama
         /// <returns></returns>
         public LLamaContext Clone()
         {
-            using var pin = Params.ToLlamaContextParams(out var lparams);
+            Params.ToLlamaContextParams(out var lparams);
             var clone = _ctx.Clone(lparams);
             return new LLamaContext(clone, Params);
         }

         /// <summary>
@@ -177,19 +177,6 @@ namespace LLama
             fileStream.SetLength(writtenBytes);
         }

-        /// <summary>
-        /// Get the state data as a byte array.
-        /// </summary>
-        /// <returns></returns>
-        [Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
-        public byte[] GetStateData()
-        {
-            var stateSize = NativeApi.llama_get_state_size(_ctx);
-            byte[] stateMemory = new byte[stateSize];
-            NativeApi.llama_copy_state_data(_ctx, stateMemory);
-            return stateMemory;
-        }
-
         /// <summary>
         /// Get the state data as an opaque handle
         /// </summary>
@@ -198,31 +185,28 @@ namespace LLama
         {
             var stateSize = _ctx.GetStateSize();

-            unsafe
-            {
-                // Allocate a chunk of memory large enough to hold the entire state
-                var memory = Marshal.AllocHGlobal((nint)stateSize);
-                try
-                {
-                    // Copy the state data into memory, discover the actual size required
-                    var actualSize = _ctx.GetState(memory, stateSize);
-
-                    // Shrink to size
-                    memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
-
-                    // Wrap memory in a "state"
-                    var state = new State(memory);
-
-                    // Set memory to zero, to prevent it being freed in finally block
-                    memory = IntPtr.Zero;
-
-                    return state;
-                }
-                finally
-                {
-                    if (memory != IntPtr.Zero)
-                        Marshal.FreeHGlobal(memory);
-                }
-            }
+            // Allocate a chunk of memory large enough to hold the entire state
+            var memory = Marshal.AllocHGlobal((nint)stateSize);
+            try
+            {
+                // Copy the state data into memory, discover the actual size required
+                var actualSize = _ctx.GetState(memory, stateSize);
+
+                // Shrink to size
+                memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
+
+                // Wrap memory in a "state"
+                var state = new State(memory);
+
+                // Set memory to zero, to prevent it being freed in finally block
+                memory = IntPtr.Zero;
+
+                return state;
+            }
+            finally
+            {
+                if (memory != IntPtr.Zero)
+                    Marshal.FreeHGlobal(memory);
+            }
         }
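Note: the `try`/`finally` above is an ownership hand-off: once the unmanaged buffer is wrapped in a `State`, the local is zeroed so the `finally` block cannot double-free it. A self-contained sketch of the same idiom (the wrapper type here is a hypothetical stand-in for `LLamaContext.State`):

```csharp
using System;
using System.Runtime.InteropServices;

sealed class UnmanagedBlock : IDisposable // hypothetical stand-in
{
    private IntPtr _ptr;
    public UnmanagedBlock(IntPtr ptr) => _ptr = ptr;

    public static UnmanagedBlock Allocate(int size)
    {
        var memory = Marshal.AllocHGlobal(size);
        try
        {
            var block = new UnmanagedBlock(memory);
            memory = IntPtr.Zero; // ownership transferred to the wrapper
            return block;
        }
        finally
        {
            if (memory != IntPtr.Zero)
                Marshal.FreeHGlobal(memory); // only runs if wrapping threw
        }
    }

    public void Dispose()
    {
        if (_ptr != IntPtr.Zero) { Marshal.FreeHGlobal(_ptr); _ptr = IntPtr.Zero; }
    }
}
```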
@@ -247,21 +231,6 @@ namespace LLama
             }
         }

-        /// <summary>
-        /// Load the state from memory.
-        /// </summary>
-        /// <param name="stateData"></param>
-        /// <exception cref="RuntimeError"></exception>
-        public void LoadState(byte[] stateData)
-        {
-            int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
-            if (stateData.Length > stateSize)
-            {
-                throw new RuntimeError("Failed to validate state size.");
-            }
-            NativeApi.llama_set_state_data(_ctx, stateData);
-        }
-
         /// <summary>
         /// Load the state from memory.
         /// </summary>
@@ -463,15 +432,15 @@ namespace LLama
         public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
         {
             var total = tokens.Length;
-            for(var i = 0; i < total; i += Params.BatchSize)
+            for(var i = 0; i < total; i += (int)Params.BatchSize)
             {
                 var n_eval = total - i;
                 if (n_eval > Params.BatchSize)
                 {
-                    n_eval = Params.BatchSize;
+                    n_eval = (int)Params.BatchSize;
                 }

-                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
+                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
                 {
                     _logger?.LogError($"[LLamaContext] Failed to eval.");
                     throw new RuntimeError("Failed to eval.");
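Note: `BatchSize` is now `uint` while span lengths are `int`, hence the casts; the `int`/`uint` comparison itself is legal because both operands promote to `long`. For example, with `total = 1200` and `BatchSize = 512` the loop evaluates chunks of 512, 512 and 176 tokens. A top-level sketch of the arithmetic:

```csharp
int total = 1200;
uint batchSize = 512;
for (var i = 0; i < total; i += (int)batchSize)
{
    var n_eval = total - i;
    if (n_eval > batchSize)
        n_eval = (int)batchSize;
    System.Console.WriteLine($"eval tokens [{i}..{i + n_eval})"); // 512, 512, 176
}
```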
@@ -35,6 +35,16 @@ namespace LLama
         /// </summary>
         public int ContextSize => NativeHandle.ContextSize;

+        /// <summary>
+        /// Get the size of this model in bytes
+        /// </summary>
+        public ulong SizeInBytes => NativeHandle.SizeInBytes;
+
+        /// <summary>
+        /// Get the number of parameters in this model
+        /// </summary>
+        public ulong ParameterCount => NativeHandle.ParameterCount;
+
         /// <summary>
         /// Dimension of embedding vectors
         /// </summary>
@@ -53,11 +63,11 @@ namespace LLama
         /// <returns></returns>
         public static LLamaWeights LoadFromFile(IModelParams @params)
         {
-            using var pin = @params.ToLlamaContextParams(out var lparams);
+            using var pin = @params.ToLlamaModelParams(out var lparams);
             var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

             if (!string.IsNullOrEmpty(@params.LoraAdapter))
-                weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
+                weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraAdapterScale, @params.LoraBase, @params.Threads);

             return new LLamaWeights(weights, @params.Encoding);
         }
@@ -19,32 +19,27 @@ namespace LLama.Native
         /// <summary>
         /// RNG seed, -1 for random
         /// </summary>
-        public int seed;
+        public uint seed;

         /// <summary>
         /// text context
         /// </summary>
-        public int n_ctx;
+        public uint n_ctx;

         /// <summary>
         /// prompt processing batch size
         /// </summary>
-        public int n_batch;
+        public uint n_batch;

         /// <summary>
-        /// number of layers to store in VRAM
+        /// number of threads to use for generation
         /// </summary>
-        public int n_gpu_layers;
+        public uint n_threads;

         /// <summary>
-        /// the GPU that is used for scratch and small tensors
+        /// number of threads to use for batch processing
         /// </summary>
-        public int main_gpu;
-
-        /// <summary>
-        /// how to split layers across multiple GPUs
-        /// </summary>
-        public nint tensor_split;
+        public uint n_threads_batch;

         /// <summary>
         /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -58,26 +53,6 @@ namespace LLama.Native
         /// </summary>
         public float rope_freq_scale;

-        /// <summary>
-        /// called with a progress value between 0 and 1, pass NULL to disable
-        /// </summary>
-        public IntPtr progress_callback;
-
-        /// <summary>
-        /// context pointer passed to the progress callback
-        /// </summary>
-        public IntPtr progress_callback_user_data;
-
-        /// <summary>
-        /// if true, reduce VRAM usage at the cost of performance
-        /// </summary>
-        public bool low_vram
-        {
-            readonly get => Convert.ToBoolean(_low_vram);
-            set => _low_vram = Convert.ToSByte(value);
-        }
-        private sbyte _low_vram;
-
         /// <summary>
         /// if true, use experimental mul_mat_q kernels
         /// </summary>
@@ -108,36 +83,6 @@ namespace LLama.Native
         }
         private sbyte _logits_all;

-        /// <summary>
-        /// only load the vocabulary, no weights
-        /// </summary>
-        public bool vocab_only
-        {
-            readonly get => Convert.ToBoolean(_vocab_only);
-            set => _vocab_only = Convert.ToSByte(value);
-        }
-        private sbyte _vocab_only;
-
-        /// <summary>
-        /// use mmap if possible
-        /// </summary>
-        public bool use_mmap
-        {
-            readonly get => Convert.ToBoolean(_use_mmap);
-            set => _use_mmap = Convert.ToSByte(value);
-        }
-        private sbyte _use_mmap;
-
-        /// <summary>
-        /// force system to keep model in RAM
-        /// </summary>
-        public bool use_mlock
-        {
-            readonly get => Convert.ToBoolean(_use_mlock);
-            set => _use_mlock = Convert.ToSByte(value);
-        }
-        private sbyte _use_mlock;
-
         /// <summary>
         /// embedding mode only
         /// </summary>
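Note: the `bool` fields that remain in this struct keep the `sbyte` backing-field idiom so the struct stays blittable against llama.cpp's one-byte C `bool`. A minimal sketch of the idiom in isolation (names are illustrative):

```csharp
using System;
using System.Runtime.InteropServices;

[StructLayout(LayoutKind.Sequential)]
struct NativeFlags
{
    // Marshal a one-byte native bool without breaking blittability.
    public bool enabled
    {
        readonly get => Convert.ToBoolean(_enabled);
        set => _enabled = Convert.ToSByte(value);
    }
    private sbyte _enabled;
}
```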
@@ -0,0 +1,67 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native
+{
+    /// <summary>
+    /// A C# representation of the llama.cpp `llama_model_params` struct
+    /// </summary>
+    [StructLayout(LayoutKind.Sequential)]
+    public unsafe struct LLamaModelParams
+    {
+        /// <summary>
+        /// number of layers to store in VRAM
+        /// </summary>
+        public int n_gpu_layers;
+
+        /// <summary>
+        /// the GPU that is used for scratch and small tensors
+        /// </summary>
+        public int main_gpu;
+
+        /// <summary>
+        /// how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
+        /// </summary>
+        public float* tensor_split;
+
+        /// <summary>
+        /// called with a progress value between 0 and 1, pass NULL to disable
+        /// </summary>
+        LlamaProgressCallback progress_callback;
+
+        /// <summary>
+        /// context pointer passed to the progress callback
+        /// </summary>
+        void* progress_callback_user_data;
+
+        /// <summary>
+        /// only load the vocabulary, no weights
+        /// </summary>
+        public bool vocab_only
+        {
+            readonly get => Convert.ToBoolean(_vocab_only);
+            set => _vocab_only = Convert.ToSByte(value);
+        }
+        private sbyte _vocab_only;
+
+        /// <summary>
+        /// use mmap if possible
+        /// </summary>
+        public bool use_mmap
+        {
+            readonly get => Convert.ToBoolean(_use_mmap);
+            set => _use_mmap = Convert.ToSByte(value);
+        }
+        private sbyte _use_mmap;
+
+        /// <summary>
+        /// force system to keep model in RAM
+        /// </summary>
+        public bool use_mlock
+        {
+            readonly get => Convert.ToBoolean(_use_mlock);
+            set => _use_mlock = Convert.ToSByte(value);
+        }
+        private sbyte _use_mlock;
+    }
+}
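Note: a hedged sketch of how the new struct is meant to be obtained and filled, mirroring `ToLlamaModelParams` earlier in this diff (field values are illustrative):

```csharp
using LLama.Native;

// Start from llama.cpp's defaults, then override only what you need.
var mparams = NativeApi.llama_model_default_params();
mparams.n_gpu_layers = 20;
mparams.use_mmap = true;
mparams.vocab_only = false;
```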
@@ -36,5 +36,15 @@ namespace LLama.Native
             set => _quantize_output_tensor = Convert.ToSByte(value);
         }
         private sbyte _quantize_output_tensor;
+
+        /// <summary>
+        /// only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
+        /// </summary>
+        public bool only_copy
+        {
+            get => Convert.ToBoolean(_only_copy);
+            set => _only_copy = Convert.ToSByte(value);
+        }
+        private sbyte _only_copy;
     }
 }
@@ -2,7 +2,6 @@
 using System.Buffers;
 using System.Runtime.InteropServices;
 using System.Text;
-using LLama.Common;
 using LLama.Exceptions;

 #pragma warning disable IDE1006 // Naming Styles
@@ -110,6 +109,13 @@ namespace LLama.Native
         [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
         public static extern bool llama_empty_call();

+        /// <summary>
+        /// Create a LLamaModelParams with default values
+        /// </summary>
+        /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern LLamaModelParams llama_model_default_params();
+
         /// <summary>
         /// Create a LLamaContextParams with default values
         /// </summary>
@@ -138,18 +144,6 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern bool llama_mlock_supported();

-        /// <summary>
-        /// Export a static computation graph for context of 511 and batch size of 1
-        /// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
-        /// parameters here to keep things simple
-        /// IMPORTANT: do not use for anything else other than debugging and testing!
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <param name="fname"></param>
-        /// <returns></returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname);
-
         /// <summary>
         /// Various functions for loading a ggml llama model.
         /// Allocate (almost) all memory needed for the model.
@@ -159,7 +153,7 @@ namespace LLama.Native
         /// <param name="params"></param>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params);
+        public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params);
@@ -192,7 +186,7 @@ namespace LLama.Native
         /// <param name="model"></param>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_free_model(IntPtr model);

         /// <summary>
         /// Apply a LoRA adapter to a loaded model
         /// path_base_model is the path to a higher quality model to use as a base for
@@ -202,11 +196,12 @@ namespace LLama.Native
         /// </summary>
         /// <param name="model_ptr"></param>
         /// <param name="path_lora"></param>
+        /// <param name="scale"></param>
         /// <param name="path_base_model"></param>
         /// <param name="n_threads"></param>
         /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);
+        public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);

         /// <summary>
         /// Returns the number of tokens in the KV cache
@@ -222,7 +217,7 @@ namespace LLama.Native
         /// <param name="ctx"></param>
         /// <param name="seed"></param>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, int seed);
+        public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed);
@@ -243,21 +238,6 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);

-        /// <summary>
-        /// Copies the state to the specified destination address.
-        /// Destination needs to have allocated enough memory (see llama_get_state_size)
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <param name="dest"></param>
-        /// <returns>the number of bytes copied</returns>
-        public static ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte[] dest)
-        {
-            fixed (byte* dstPtr = &dest[0])
-            {
-                return llama_copy_state_data(ctx, dstPtr);
-            }
-        }
-
         /// <summary>
         /// Set the state reading from the specified address
         /// </summary>
@@ -267,20 +247,6 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);

-        /// <summary>
-        /// Set the state reading from the specified address
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <param name="src"></param>
-        /// <returns>the number of bytes read</returns>
-        public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
-        {
-            fixed (byte* srcPtr = &src[0])
-            {
-                return llama_set_state_data(ctx, srcPtr);
-            }
-        }
-
         /// <summary>
         /// Load session file
         /// </summary>
@@ -313,24 +279,9 @@ namespace LLama.Native
         /// <param name="tokens"></param>
         /// <param name="n_tokens"></param>
         /// <param name="n_past"></param>
-        /// <param name="n_threads"></param>
         /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);
-
-        /// <summary>
-        /// Run the llama inference to obtain the logits and probabilities for the next token.
-        /// tokens + n_tokens is the provided batch of new tokens to process
-        /// n_past is the number of tokens to use from previous eval calls
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <param name="tokens"></param>
-        /// <param name="n_tokens"></param>
-        /// <param name="n_past"></param>
-        /// <param name="n_threads"></param>
-        /// <returns>Returns 0 on success</returns>
-        [DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);
+        public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);

         /// <summary>
         /// Convert the provided text into tokens.
@@ -364,7 +315,7 @@ namespace LLama.Native
                 // Do the actual tokenization
                 fixed (byte* arrayPtr = array)
                 fixed (llama_token* tokensPtr = tokens)
-                    return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos);
+                    return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos);
             }
             finally
             {
@@ -372,28 +323,6 @@ namespace LLama.Native
             }
         }
-        /// <summary>
-        /// Convert the provided text into tokens.
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <param name="text"></param>
-        /// <param name="tokens"></param>
-        /// <param name="n_max_tokens"></param>
-        /// <param name="add_bos"></param>
-        /// <returns>Returns the number of tokens on success, no more than n_max_tokens.
-        /// Returns a negative number on failure - the number of tokens that would have been returned
-        /// </returns>
-        [DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos);
-
-        /// <summary>
-        /// Get the number of tokens in the model vocabulary for this context
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <returns></returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_n_vocab(SafeLLamaContextHandle ctx);
-
         /// <summary>
         /// Get the size of the context window for the model for this context
         /// </summary>
@@ -402,14 +331,6 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_n_ctx(SafeLLamaContextHandle ctx);

-        /// <summary>
-        /// Get the dimension of embedding vectors from the model for this context
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <returns></returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_n_embd(SafeLLamaContextHandle ctx);
-
         /// <summary>
         /// Token logits obtained from the last call to llama_eval()
         /// The logits for the last token are stored in the last row
@@ -431,15 +352,6 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx);

-        /// <summary>
-        /// Token Id -> String. Uses the vocabulary in the provided context
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <param name="token"></param>
-        /// <returns>Pointer to a string.</returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token);
-
         /// <summary>
         /// Get the "Beginning of sentence" token
         /// </summary>
@@ -488,7 +400,7 @@ namespace LLama.Native
         /// <param name="model"></param>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_n_vocab(SafeLlamaModelHandle model);
+        public static extern int llama_n_vocab(SafeLlamaModelHandle model);

         /// <summary>
         /// Get the size of the context window for the model
@@ -496,7 +408,7 @@ namespace LLama.Native
         /// <param name="model"></param>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_n_ctx(SafeLlamaModelHandle model);
+        public static extern int llama_n_ctx_train(SafeLlamaModelHandle model);

         /// <summary>
         /// Get the dimension of embedding vectors from this model
@@ -504,7 +416,23 @@ namespace LLama.Native
         /// <param name="model"></param>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_n_embd(SafeLlamaModelHandle model);
+        public static extern int llama_n_embd(SafeLlamaModelHandle model);
+
+        /// <summary>
+        /// Get the size of the model in bytes
+        /// </summary>
+        /// <param name="model"></param>
+        /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern ulong llama_model_size(SafeLlamaModelHandle model);
+
+        /// <summary>
+        /// Get the number of parameters in this model
+        /// </summary>
+        /// <param name="model"></param>
+        /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern ulong llama_model_n_params(SafeLlamaModelHandle model);
| /// <summary> | /// <summary> | ||||
| /// Convert a single token into text | /// Convert a single token into text | ||||
| @@ -515,13 +443,14 @@ namespace LLama.Native | |||||
| /// <param name="length">size of the buffer</param> | /// <param name="length">size of the buffer</param> | ||||
| /// <returns>The length writte, or if the buffer is too small a negative that indicates the length required</returns> | /// <returns>The length writte, or if the buffer is too small a negative that indicates the length required</returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); | |||||
| public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert text into tokens | /// Convert text into tokens | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="model"></param> | /// <param name="model"></param> | ||||
| /// <param name="text"></param> | /// <param name="text"></param> | ||||
| /// <param name="text_len"></param> | |||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| /// <param name="n_max_tokens"></param> | /// <param name="n_max_tokens"></param> | ||||
| /// <param name="add_bos"></param> | /// <param name="add_bos"></param> | ||||
| @@ -529,7 +458,7 @@ namespace LLama.Native | |||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | /// Returns a negative number on failure - the number of tokens that would have been returned | ||||
| /// </returns> | /// </returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos); | |||||
| public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos); | |||||
| /// <summary> | /// <summary> | ||||
| /// Register a callback to receive llama log messages | /// Register a callback to receive llama log messages | ||||
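Note: the tokenizer entry points now hang off the model rather than the context, and `llama_tokenize` takes an explicit byte length instead of relying on null termination. A hedged calling sketch, assuming you already hold a `SafeLlamaModelHandle` named `model` (buffer capacity is illustrative):

```csharp
var bytes = System.Text.Encoding.UTF8.GetBytes("Hello llama");
var tokens = new int[64]; // illustrative capacity
unsafe
{
    fixed (byte* textPtr = bytes)
    fixed (int* tokenPtr = tokens)
    {
        var n = NativeApi.llama_tokenize(model, textPtr, bytes.Length, tokenPtr, tokens.Length, add_bos: true);
        // n >= 0: token count written; n < 0: -n tokens would be required
    }
}
```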
@@ -69,12 +69,13 @@ namespace LLama.Native
         /// <inheritdoc />
         protected override bool ReleaseHandle()
         {
-            NativeApi.llama_free(DangerousGetHandle());
-            SetHandle(IntPtr.Zero);
-
             // Decrement refcount on model
             _model?.DangerousRelease();
             _model = null!;

+            NativeApi.llama_free(handle);
+            SetHandle(IntPtr.Zero);
+
             return true;
         }

@@ -234,15 +235,14 @@ namespace LLama.Native
         /// </summary>
         /// <param name="tokens">The provided batch of new tokens to process</param>
         /// <param name="n_past">the number of tokens to use from previous eval calls</param>
-        /// <param name="n_threads"></param>
         /// <returns>Returns true on success</returns>
-        public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
+        public bool Eval(ReadOnlySpan<int> tokens, int n_past)
         {
             unsafe
             {
                 fixed (int* pinned = tokens)
                 {
-                    return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
+                    return NativeApi.llama_eval(this, pinned, tokens.Length, n_past) == 0;
                 }
             }
         }
@@ -29,18 +29,30 @@ namespace LLama.Native
         /// </summary>
         public int EmbeddingSize { get; }

+        /// <summary>
+        /// Get the size of this model in bytes
+        /// </summary>
+        public ulong SizeInBytes { get; }
+
+        /// <summary>
+        /// Get the number of parameters in this model
+        /// </summary>
+        public ulong ParameterCount { get; }
+
         internal SafeLlamaModelHandle(IntPtr handle)
             : base(handle)
         {
-            VocabCount = NativeApi.llama_model_n_vocab(this);
-            ContextSize = NativeApi.llama_model_n_ctx(this);
-            EmbeddingSize = NativeApi.llama_model_n_embd(this);
+            VocabCount = NativeApi.llama_n_vocab(this);
+            ContextSize = NativeApi.llama_n_ctx_train(this);
+            EmbeddingSize = NativeApi.llama_n_embd(this);
+            SizeInBytes = NativeApi.llama_model_size(this);
+            ParameterCount = NativeApi.llama_model_n_params(this);
         }

         /// <inheritdoc />
         protected override bool ReleaseHandle()
         {
-            NativeApi.llama_free_model(handle);
+            NativeApi.llama_free_model(DangerousGetHandle());
             SetHandle(IntPtr.Zero);
             return true;
         }
@@ -52,7 +64,7 @@ namespace LLama.Native
         /// <param name="lparams"></param>
         /// <returns></returns>
         /// <exception cref="RuntimeError"></exception>
-        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
+        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
         {
             var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
             if (model_ptr == IntPtr.Zero)
@@ -62,19 +74,22 @@ namespace LLama.Native
         }

         #region LoRA
+
         /// <summary>
         /// Apply a LoRA adapter to a loaded model
         /// </summary>
         /// <param name="lora"></param>
+        /// <param name="scale"></param>
         /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
         /// adapter. Can be NULL to use the current loaded model.</param>
         /// <param name="threads"></param>
         /// <exception cref="RuntimeError"></exception>
-        public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
+        public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int threads = -1)
         {
             var err = NativeApi.llama_model_apply_lora_from_file(
                 this,
                 lora,
+                scale,
                 string.IsNullOrEmpty(modelBase) ? null : modelBase,
                 threads
             );
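Note: a hedged usage sketch of the new `scale` parameter (paths are placeholders; 1.0 applies the adapter at full strength, matching the `LoraAdapterScale` default added earlier in this diff):

```csharp
// Apply a LoRA adapter at full strength via the model handle.
weights.NativeHandle.ApplyLoraFromFile("path/to/adapter.bin", scale: 1.0f);
```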
@@ -97,7 +112,7 @@ namespace LLama.Native
             {
                 fixed (byte* destPtr = dest)
                 {
-                    var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length);
+                    var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
                     return Math.Abs(length);
                 }
             }
@@ -113,7 +128,7 @@ namespace LLama.Native
         {
             unsafe
             {
-                var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
+                var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
                 if (length == 0)
                     return "";

@@ -121,7 +136,7 @@ namespace LLama.Native
                 fixed (byte* bytePtr = bytes)
                 {
-                    var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
+                    var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
                     Debug.Assert(written == bytes.Length);

                     return encoding.GetString(bytePtr, bytes.Length);
@@ -139,7 +154,7 @@ namespace LLama.Native
             {
                 unsafe
                 {
-                    var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
+                    var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
                     if (length == 0)
                         return;

@@ -147,7 +162,7 @@ namespace LLama.Native
                 fixed (byte* bytePtr = bytes)
                 {
                     // Decode into bytes
-                    var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
+                    var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
                     Debug.Assert(written == bytes.Length);

                     // Decode into chars
@@ -276,13 +291,13 @@ namespace LLama.Native
                 fixed (byte* bytesPtr = &bytes[0])
                 {
                     // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
-                    var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
+                    var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos);

                     // Tokenize again, this time outputting into an array of exactly the right size
                     var tokens = new int[count];
                     fixed (int* tokensPtr = &tokens[0])
                     {
-                        NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
+                        NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos);
                         return tokens;
                     }
                 }
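Note: this count-then-fill idiom is a standard pattern for native APIs that report the required buffer size as a negative return value. A self-contained sketch against a hypothetical stand-in for `llama_tokenize`:

```csharp
using System;

// Hypothetical stand-in: with insufficient capacity it returns -(required
// length); with enough capacity it fills the buffer and returns the count.
static int CountOrFill(ReadOnlySpan<char> text, Span<int> output)
{
    if (output.Length < text.Length)
        return -text.Length;
    for (var i = 0; i < text.Length; i++)
        output[i] = text[i]; // trivial "tokenization"
    return text.Length;
}

var input = "hello".AsSpan();
var count = -CountOrFill(input, Span<int>.Empty); // pass 1: measure
var tokens = new int[count];
CountOrFill(input, tokens);                       // pass 2: fill exactly
Console.WriteLine(string.Join(",", tokens));
```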
@@ -22,8 +22,8 @@ namespace LLama
         {
             using var weights = LLamaWeights.LoadFromFile(@params);

-            using (@params.ToLlamaContextParams(out var lparams))
-                return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams);
+            @params.ToLlamaContextParams(out var lparams);
+            return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams);
         }

         [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
@@ -47,11 +47,11 @@ namespace LLama
         [Obsolete("Use SafeLLamaContextHandle Eval method instead")]
 #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
-        public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
+        public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past)
 #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
         {
             var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
-            return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
+            return ctx.Eval(slice, n_past) ? 0 : 1;
         }

         [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]