Major llama.cpp API Change (tags/v0.6.0)
@@ -8,7 +8,7 @@ namespace LLama.Examples.NewVersion
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
-            var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
+            var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim();

             var parameters = new ModelParams(modelPath)
             {
@@ -50,7 +50,7 @@ namespace LLama.Examples.NewVersion
                 Console.ForegroundColor = ConsoleColor.White;
                 ex.Context.Dispose();
-                ex = new(new LLamaContext(parameters));
+                ex = new(new LLamaContext(model, parameters));
                 session = new ChatSession(ex);
                 session.LoadSession(statePath);
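
The hunks above show the core of the migration: a `LLamaContext` is no longer built directly from `ModelParams`; the weights are loaded once with `LLamaWeights.LoadFromFile` and every context is then created from those weights. A minimal sketch of the new pattern (the model path and settings are placeholders, not values from this PR):

```csharp
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf") // hypothetical path
{
    ContextSize = 1024,
    Seed = 1337u,
};

// Load the weights once...
using var model = LLamaWeights.LoadFromFile(parameters);

// ...then create contexts and executors from them
using var context = new LLamaContext(model, parameters);
var executor = new InteractiveExecutor(context);
```
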
@@ -1,13 +1,7 @@
-using System.Reflection.Metadata;
-using System.Security.Cryptography;
-using System.Text;
-using LLama.Abstractions;
+using System.Security.Cryptography;
 using LLama.Common;
-using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.AI.ChatCompletion;
-using Microsoft.SemanticKernel.AI.TextCompletion;
 using LLamaSharp.SemanticKernel.ChatCompletion;
-using LLamaSharp.SemanticKernel.TextCompletion;

 namespace LLama.Examples.NewVersion
 {
@@ -22,7 +16,7 @@ namespace LLama.Examples.NewVersion
             // Load weights into memory
             var parameters = new ModelParams(modelPath)
             {
-                Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
+                Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
             };
             using var model = LLamaWeights.LoadFromFile(parameters);
             using var context = model.CreateContext(parameters);
@@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
-            var seed = 1337;
+            var seed = 1337u;

             // Load weights into memory
             var parameters = new ModelParams(modelPath)
             {
@@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion
             // Load weights into memory
             var parameters = new ModelParams(modelPath)
             {
-                Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
+                Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
             };
             using var model = LLamaWeights.LoadFromFile(parameters);
             var ex = new StatelessExecutor(model, parameters);
@@ -15,7 +15,7 @@ namespace LLama.Examples.NewVersion
             // Load weights into memory
             var @params = new ModelParams(modelPath)
             {
-                Seed = RandomNumberGenerator.GetInt32(int.MaxValue)
+                Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
             };
             using var weights = LLamaWeights.LoadFromFile(@params);
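
All seed-related values are now unsigned (`uint`), which is why the examples above wrap `RandomNumberGenerator.GetInt32` in an `unchecked` cast and write the literal as `1337u`. The same pattern in isolation (placeholder path):

```csharp
using System.Security.Cryptography;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf") // hypothetical path
{
    // Seed is now uint; a random non-negative int can be cast directly,
    // or use a fixed seed for reproducible runs: Seed = 1337u
    Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
};
```
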
@@ -1,4 +1,5 @@
 using LLama.Examples.NewVersion;
+using LLama.Native;

 Console.WriteLine("======================================================================================================");
@@ -7,7 +8,7 @@ Console.WriteLine(" __ __ ____ _
 Console.WriteLine("======================================================================================================");

+NativeApi.llama_empty_call();
 Console.WriteLine();

 await NewVersionTestRunner.Run();
@@ -27,36 +27,8 @@ namespace LLama.Unittest
         public void BasicModelProperties()
         {
             Assert.Equal(32000, _model.VocabCount);
-            Assert.Equal(2048, _model.ContextSize);
+            Assert.Equal(4096, _model.ContextSize);
             Assert.Equal(4096, _model.EmbeddingSize);
-            Assert.Equal(Encoding.UTF8, _model.Encoding);
-        }
-
-        [Fact]
-        public void CloneContext()
-        {
-            var original = _model.CreateContext(_params);
-
-            // Evaluate something (doesn't matter what, as long as it begins with token 1)
-            original.Eval(new[] { 1, 42, 321 }, 0);
-
-            // Clone current state
-            var clone = original.Clone();
-
-            // Now evaluate something more
-            var reply1a = original.Eval(new[] { 4, 5, 6 }, 3);
-            var reply2a = original.Eval(new[] { 7, 8, 9 }, 6);
-
-            // Assert that the context replied differently each time
-            Assert.NotEqual(reply1a, reply2a);
-
-            // Give the same prompts to the cloned state
-            var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3);
-            var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6);
-
-            // Assert that the cloned context replied in the same way as originally
-            Assert.Equal(reply1a, reply1b);
-            Assert.Equal(reply2a, reply2b);
         }
     }
 }
@@ -2,7 +2,7 @@
 namespace LLama.Unittest
 {
-    public class LLamaContextTests
+    public sealed class LLamaContextTests
         : IDisposable
     {
         private readonly LLamaWeights _weights;
@@ -30,7 +30,6 @@ namespace LLama.Unittest
             Assert.Equal(768, _context.ContextSize);
             Assert.Equal(4096, _context.EmbeddingSize);
             Assert.Equal(32000, _context.VocabCount);
-            Assert.Equal(0, _context.KVCacheTokenCount);
         }

         [Fact]
@@ -13,7 +13,6 @@ namespace LLama.Unittest
         {
             BatchSize = 17,
             ContextSize = 42,
-            LoraAdapter = "adapter",
             Seed = 42,
             GpuLayerCount = 111
         };
@@ -31,9 +30,13 @@ namespace LLama.Unittest
         {
             BatchSize = 17,
             ContextSize = 42,
-            LoraAdapter = "adapter",
             Seed = 42,
-            GpuLayerCount = 111
+            GpuLayerCount = 111,
+            LoraAdapters =
+            {
+                new("abc", 1),
+                new("def", 0)
+            }
         };

         var settings = new Newtonsoft.Json.JsonSerializerSettings();
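
The serialization test above also shows the replacement for the old single `LoraAdapter` string: `LoraAdapters` is a collection of `(Path, Scale)` pairs that can be filled with a collection initializer. A hedged sketch with placeholder file names:

```csharp
using LLama.Abstractions;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf") // hypothetical path
{
    LoraAdapters =
    {
        // (Path, Scale); adapters with an empty path or a scale <= 0 are skipped at load time
        new LoraAdapter("loras/style.bin", 0.8f),
        new LoraAdapter("loras/domain.bin", 1.0f),
    },
    LoraBase = "path/to/lora-base.gguf", // optional lora_base
};

using var model = LLamaWeights.LoadFromFile(parameters);
```
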
@@ -16,7 +16,7 @@ namespace LLama.Unittest
             _params = new ModelParams(Constants.ModelPath)
             {
                 ContextSize = 60,
-                Seed = 1754
+                Seed = 1754,
             };
             _weights = LLamaWeights.LoadFromFile(_params);
         }
@@ -48,13 +48,13 @@ namespace LLama.Unittest
         {
             var executor = new StatelessExecutor(_weights, _params);

-            const string question = " Question. why is a cat the best pet?\nAnswer: ";
+            const string question = " Question. cats or dogs?\nAnswer: ";

             // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
             // with a modified context
             var @params = new InferenceParams()
             {
-                MaxTokens = 100,
+                MaxTokens = 65,
                 TokensKeep = question.Length,
             };
@@ -27,7 +27,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensEndWith()
     {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

         var result = tokens.TokensEndsWithAnyString(new[]
         {
@@ -41,7 +41,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensEndSubstring()
     {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

         var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
         {
@@ -53,7 +53,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensNotEndWith()
     {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

         var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
         {
@@ -67,7 +67,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensNotEndWithNothing()
     {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

         var result = tokens.TokensEndsWithAnyString((IList<string>)Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
         Assert.False(result);
@@ -4,7 +4,7 @@ using LLama.Abstractions;
 namespace LLama.Web.Common
 {
     public class ModelOptions
-        : IModelParams
+        : ILLamaParams
     {
         public string Name { get; set; }
@@ -14,7 +14,7 @@ namespace LLama.Web.Common
         /// <summary>
         /// Model context size (n_ctx)
         /// </summary>
-        public int ContextSize { get; set; } = 512;
+        public uint ContextSize { get; set; } = 512;

         /// <summary>
         /// the GPU that is used for scratch and small tensors
         /// </summary>
@@ -30,7 +30,7 @@ namespace LLama.Web.Common
         /// <summary>
         /// Seed for the random number generator (seed)
         /// </summary>
-        public int Seed { get; set; } = 1686349486;
+        public uint Seed { get; set; } = 1686349486;

         /// <summary>
         /// Use f16 instead of f32 for memory kv (memory_f16)
         /// </summary>
@@ -51,26 +51,31 @@ namespace LLama.Web.Common
         /// Model path (model)
         /// </summary>
         public string ModelPath { get; set; }

         /// <summary>
-        /// model alias
-        /// </summary>
-        public string ModelAlias { get; set; } = "unknown";
-
-        /// <summary>
-        /// lora adapter path (lora_adapter)
-        /// </summary>
-        public string LoraAdapter { get; set; } = string.Empty;
-
-        /// <summary>
-        /// base model path for the lora adapter (lora_base)
-        /// </summary>
-        public string LoraBase { get; set; } = string.Empty;
-
-        /// <summary>
-        /// Number of threads (-1 = autodetect) (n_threads)
+        /// List of LoRAs to apply
         /// </summary>
-        public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
+        public AdapterCollection LoraAdapters { get; set; } = new();
+
+        /// <summary>
+        /// base model path for the lora adapter (lora_base)
+        /// </summary>
+        public string LoraBase { get; set; } = string.Empty;

         /// <summary>
-        /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+        /// Number of threads (null = autodetect) (n_threads)
         /// </summary>
-        public int BatchSize { get; set; } = 512;
+        public uint? Threads { get; set; }
+
+        /// <summary>
+        /// Number of threads to use for batch processing (null = autodetect) (n_threads)
+        /// </summary>
+        public uint? BatchThreads { get; set; }
+
+        /// <summary>
+        /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+        /// </summary>
+        public uint BatchSize { get; set; } = 512;

         /// <summary>
         /// Whether to convert eos to newline during the inference.
@@ -107,5 +112,10 @@ namespace LLama.Web.Common
         /// The encoding to use for models
         /// </summary>
         public Encoding Encoding { get; set; } = Encoding.UTF8;
+
+        /// <summary>
+        /// Load vocab only (no weights)
+        /// </summary>
+        public bool VocabOnly { get; set; }
     }
 }
@@ -3,7 +3,6 @@ using LLama.Web.Common;
 using LLama.Web.Models;
 using Microsoft.Extensions.Options;
 using System.Collections.Concurrent;
-using System.Drawing;

 namespace LLama.Web.Services
 {
@@ -50,15 +49,16 @@ namespace LLama.Web.Services
             if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances)
                 return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached"));

-            // Create model
-            var llamaModel = new LLamaContext(modelOption);
+            // Load weights
+            // todo: it would be better to have a central service which loads weights and shares them between all contexts that need them!
+            using var weights = LLamaWeights.LoadFromFile(modelOption);

             // Create executor
             ILLamaExecutor executor = executorType switch
             {
-                LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel),
-                LLamaExecutorType.Instruct => new InstructExecutor(llamaModel),
-                LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel),
+                LLamaExecutorType.Interactive => new InteractiveExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
+                LLamaExecutorType.Instruct => new InstructExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
+                LLamaExecutorType.Stateless => new StatelessExecutor(weights, modelOption),
                 _ => default
             };
@@ -16,10 +16,15 @@ public class StatefulChatService : IDisposable

     public StatefulChatService(IConfiguration configuration)
     {
-        _context = new LLamaContext(new Common.ModelParams(configuration["ModelPath"])
+        var @params = new Common.ModelParams(configuration["ModelPath"])
         {
-            ContextSize = 512
-        });
+            ContextSize = 512,
+        };
+
+        // todo: share weights from a central service
+        using var weights = LLamaWeights.LoadFromFile(@params);
+
+        _context = new LLamaContext(weights, @params);
         _session = new ChatSession(new InteractiveExecutor(_context));
     }
@@ -12,10 +12,16 @@ namespace LLama.WebAPI.Services

     public StatelessChatService(IConfiguration configuration)
     {
-        _context = new LLamaContext(new ModelParams(configuration["ModelPath"])
+        var @params = new Common.ModelParams(configuration["ModelPath"])
         {
             ContextSize = 512,
-        });
+        };
+
+        // todo: share weights from a central service
+        using var weights = LLamaWeights.LoadFromFile(@params);
+
+        _context = new LLamaContext(weights, @params);

         // TODO: replace with a stateless executor
         _session = new ChatSession(new InteractiveExecutor(_context))
             .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8))
@@ -0,0 +1,70 @@
+using System.Text;
+
+namespace LLama.Abstractions;
+
+/// <summary>
+/// The parameters for initializing a LLama context from a model.
+/// </summary>
+public interface IContextParams
+{
+    /// <summary>
+    /// Model context size (n_ctx)
+    /// </summary>
+    uint ContextSize { get; set; }
+
+    /// <summary>
+    /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+    /// </summary>
+    uint BatchSize { get; set; }
+
+    /// <summary>
+    /// Seed for the random number generator (seed)
+    /// </summary>
+    uint Seed { get; set; }
+
+    /// <summary>
+    /// Use f16 instead of f32 for memory kv (memory_f16)
+    /// </summary>
+    bool UseFp16Memory { get; set; }
+
+    /// <summary>
+    /// Compute perplexity over the prompt (perplexity)
+    /// </summary>
+    bool Perplexity { get; set; }
+
+    /// <summary>
+    /// Whether to use embedding mode. (embedding) Note that if this is set to true,
+    /// The LLamaModel won't produce text response anymore.
+    /// </summary>
+    bool EmbeddingMode { get; set; }
+
+    /// <summary>
+    /// RoPE base frequency
+    /// </summary>
+    float RopeFrequencyBase { get; set; }
+
+    /// <summary>
+    /// RoPE frequency scaling factor
+    /// </summary>
+    float RopeFrequencyScale { get; set; }
+
+    /// <summary>
+    /// Use experimental mul_mat_q kernels
+    /// </summary>
+    bool MulMatQ { get; set; }
+
+    /// <summary>
+    /// The encoding to use for models
+    /// </summary>
+    Encoding Encoding { get; set; }
+
+    /// <summary>
+    /// Number of threads (null = autodetect) (n_threads)
+    /// </summary>
+    uint? Threads { get; set; }
+
+    /// <summary>
+    /// Number of threads to use for batch processing (null = autodetect) (n_threads)
+    /// </summary>
+    uint? BatchThreads { get; set; }
+}
@@ -36,7 +36,7 @@ namespace LLama.Abstractions
         /// </summary>
         public int TopK { get; set; }

-        /// <summary>
+        /// <summary>llama_eval
         /// 1.0 = disabled
         /// </summary>
         public float TopP { get; set; }
@@ -0,0 +1,11 @@
+namespace LLama.Abstractions
+{
+    /// <summary>
+    /// Convenience interface for implementing both type of parameters.
+    /// </summary>
+    /// <remarks>Mostly exists for backwards compatibility reasons, when these two were not split.</remarks>
+    public interface ILLamaParams
+        : IModelParams, IContextParams
+    {
+    }
+}
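
In practice `ModelParams` (and the web project's `ModelOptions`) implement `ILLamaParams`, so one options object can still be handed to both halves of the API: `IModelParams` consumers such as `LLamaWeights.LoadFromFile`, and `IContextParams` consumers such as `LLamaWeights.CreateContext`. A small sketch (placeholder path):

```csharp
using LLama;
using LLama.Common;

// ModelParams : ILLamaParams, i.e. both IModelParams and IContextParams
var parameters = new ModelParams("path/to/model.gguf");

// IModelParams half: load the weights once
using var weights = LLamaWeights.LoadFromFile(parameters);

// IContextParams half: create as many contexts as needed from the same weights
using var context1 = weights.CreateContext(parameters);
using var context2 = weights.CreateContext(parameters);
```
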
@@ -1,4 +1,6 @@
-using System.Text;
+using System;
+using System.Collections.Generic;
+using System.Linq;

 namespace LLama.Abstractions
 {
@@ -7,36 +9,16 @@ namespace LLama.Abstractions
     /// </summary>
     public interface IModelParams
     {
-        /// <summary>
-        /// Model context size (n_ctx)
-        /// </summary>
-        int ContextSize { get; set; }
-
         /// <summary>
         /// the GPU that is used for scratch and small tensors
         /// </summary>
         int MainGpu { get; set; }

-        /// <summary>
-        /// if true, reduce VRAM usage at the cost of performance
-        /// </summary>
-        bool LowVram { get; set; }
-
         /// <summary>
         /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
         /// </summary>
         int GpuLayerCount { get; set; }

-        /// <summary>
-        /// Seed for the random number generator (seed)
-        /// </summary>
-        int Seed { get; set; }
-
-        /// <summary>
-        /// Use f16 instead of f32 for memory kv (memory_f16)
-        /// </summary>
-        bool UseFp16Memory { get; set; }
-
         /// <summary>
         /// Use mmap for faster loads (use_mmap)
         /// </summary>
@@ -47,41 +29,15 @@ namespace LLama.Abstractions
         /// </summary>
         bool UseMemoryLock { get; set; }

-        /// <summary>
-        /// Compute perplexity over the prompt (perplexity)
-        /// </summary>
-        bool Perplexity { get; set; }
-
         /// <summary>
         /// Model path (model)
         /// </summary>
         string ModelPath { get; set; }

-        /// <summary>
-        /// lora adapter path (lora_adapter)
-        /// </summary>
-        string LoraAdapter { get; set; }
-
-        /// <summary>
-        /// base model path for the lora adapter (lora_base)
-        /// </summary>
-        string LoraBase { get; set; }
-
         /// <summary>
         /// Number of threads (-1 = autodetect) (n_threads)
         /// </summary>
-        int Threads { get; set; }
-
-        /// <summary>
-        /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
-        /// </summary>
-        int BatchSize { get; set; }
-
-        /// <summary>
-        /// Whether to use embedding mode. (embedding) Note that if this is set to true,
-        /// The LLamaModel won't produce text response anymore.
-        /// </summary>
-        bool EmbeddingMode { get; set; }
+        uint? Threads { get; set; }

         /// <summary>
         /// how split tensors should be distributed across GPUs
@@ -89,23 +45,62 @@ namespace LLama.Abstractions
         float[]? TensorSplits { get; set; }

         /// <summary>
-        /// RoPE base frequency
+        /// Load vocab only (no weights)
         /// </summary>
-        float RopeFrequencyBase { get; set; }
+        bool VocabOnly { get; set; }

         /// <summary>
-        /// RoPE frequency scaling factor
+        /// List of LoRA adapters to apply
         /// </summary>
-        float RopeFrequencyScale { get; set; }
+        AdapterCollection LoraAdapters { get; }

         /// <summary>
-        /// Use experimental mul_mat_q kernels
+        /// base model path for the lora adapter (lora_base)
         /// </summary>
-        bool MulMatQ { get; set; }
-
-        /// <summary>
-        /// The encoding to use for models
-        /// </summary>
-        Encoding Encoding { get; set; }
+        string LoraBase { get; set; }
     }
+
+    /// <summary>
+    /// A LoRA adapter to apply to a model
+    /// </summary>
+    /// <param name="Path">Path to the LoRA file</param>
+    /// <param name="Scale">Strength of this LoRA</param>
+    public readonly record struct LoraAdapter(string Path, float Scale);
+
+    /// <summary>
+    /// A list of LoraAdapter objects
+    /// </summary>
+    public sealed class AdapterCollection
+        : List<LoraAdapter>, IEquatable<AdapterCollection>
+    {
+        /// <inheritdoc />
+        public bool Equals(AdapterCollection? other)
+        {
+            if (other == null)
+                return false;
+
+            return this.SequenceEqual(other);
+        }
+
+        /// <inheritdoc/>
+        public override bool Equals(object? obj)
+        {
+            return Equals(obj as AdapterCollection);
+        }
+
+        /// <inheritdoc/>
+        public override int GetHashCode()
+        {
+            unchecked
+            {
+                var hash = 17;
+                for (var i = 0; i < Count; i++)
+                {
+                    hash += this[i].GetHashCode();
+                    hash *= 7823;
+                }
+                return hash;
+            }
+        }
+    }
 }
@@ -10,20 +10,17 @@ namespace LLama.Common
     /// The parameters for initializing a LLama model.
     /// </summary>
     public record ModelParams
-        : IModelParams
+        : ILLamaParams
     {
         /// <summary>
         /// Model context size (n_ctx)
         /// </summary>
-        public int ContextSize { get; set; } = 512;
+        public uint ContextSize { get; set; } = 512;

         /// <summary>
         /// the GPU that is used for scratch and small tensors
         /// </summary>
         public int MainGpu { get; set; } = 0;

-        /// <summary>
-        /// if true, reduce VRAM usage at the cost of performance
-        /// </summary>
-        public bool LowVram { get; set; } = false;
-
         /// <summary>
         /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
         /// </summary>
@@ -31,7 +28,7 @@ namespace LLama.Common
         /// <summary>
         /// Seed for the random number generator (seed)
         /// </summary>
-        public int Seed { get; set; } = 1686349486;
+        public uint Seed { get; set; } = 1686349486;

         /// <summary>
         /// Use f16 instead of f32 for memory kv (memory_f16)
         /// </summary>
@@ -52,22 +49,31 @@ namespace LLama.Common
         /// Model path (model)
         /// </summary>
         public string ModelPath { get; set; }

         /// <summary>
-        /// lora adapter path (lora_adapter)
+        /// List of LoRAs to apply
         /// </summary>
-        public string LoraAdapter { get; set; } = string.Empty;
+        public AdapterCollection LoraAdapters { get; set; } = new();

         /// <summary>
         /// base model path for the lora adapter (lora_base)
         /// </summary>
         public string LoraBase { get; set; } = string.Empty;

         /// <summary>
-        /// Number of threads (-1 = autodetect) (n_threads)
+        /// Number of threads (null = autodetect) (n_threads)
         /// </summary>
-        public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
+        public uint? Threads { get; set; }
+
+        /// <summary>
+        /// Number of threads to use for batch processing (null = autodetect) (n_threads)
+        /// </summary>
+        public uint? BatchThreads { get; set; }

         /// <summary>
         /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
         /// </summary>
-        public int BatchSize { get; set; } = 512;
+        public uint BatchSize { get; set; } = 512;

         /// <summary>
         /// Whether to use embedding mode. (embedding) Note that if this is set to true,
@@ -95,6 +101,11 @@ namespace LLama.Common
         /// </summary>
         public bool MulMatQ { get; set; }

+        /// <summary>
+        /// Load vocab only (no weights)
+        /// </summary>
+        public bool VocabOnly { get; set; }
+
         /// <summary>
         /// The encoding to use to convert text for the model
         /// </summary>
@@ -138,10 +149,10 @@ namespace LLama.Common
         /// <param name="mulMatQ">Use experimental mul_mat_q kernels</param>
         /// <param name="encoding">The encoding to use to convert text for the model</param>
         [Obsolete("Use object initializer to set all optional parameters")]
-        public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
-                           int seed = 1337, bool useFp16Memory = true,
+        public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount = 20,
+                           uint seed = 1337, bool useFp16Memory = true,
                            bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
-                           string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
+                           string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512,
                            bool embeddingMode = false,
                            float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
                            string encoding = "UTF-8")
@@ -154,15 +165,15 @@ namespace LLama.Common
             UseMemoryLock = useMemoryLock;
             Perplexity = perplexity;
             ModelPath = modelPath;
-            LoraAdapter = loraAdapter;
             LoraBase = loraBase;
-            Threads = threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : threads;
+            Threads = threads < 1 ? null : (uint)threads;
             BatchSize = batchSize;
             EmbeddingMode = embeddingMode;
             RopeFrequencyBase = ropeFrequencyBase;
             RopeFrequencyScale = ropeFrequencyScale;
             MulMatQ = mulMatQ;
             Encoding = Encoding.GetEncoding(encoding);
+            LoraAdapters.Add(new LoraAdapter(loraAdapter, 1));
         }
     }
@@ -0,0 +1,46 @@
+using System;
+using System.IO;
+using LLama.Abstractions;
+using LLama.Native;
+
+namespace LLama.Extensions
+{
+    /// <summary>
+    /// Extention methods to the IContextParams interface
+    /// </summary>
+    public static class IContextParamsExtensions
+    {
+        /// <summary>
+        /// Convert the given `IModelParams` into a `LLamaContextParams`
+        /// </summary>
+        /// <param name="params"></param>
+        /// <param name="result"></param>
+        /// <returns></returns>
+        /// <exception cref="FileNotFoundException"></exception>
+        /// <exception cref="ArgumentException"></exception>
+        public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
+        {
+            result = NativeApi.llama_context_default_params();
+            result.n_ctx = @params.ContextSize;
+            result.n_batch = @params.BatchSize;
+            result.seed = @params.Seed;
+            result.f16_kv = @params.UseFp16Memory;
+            result.logits_all = @params.Perplexity;
+            result.embedding = @params.EmbeddingMode;
+            result.rope_freq_base = @params.RopeFrequencyBase;
+            result.rope_freq_scale = @params.RopeFrequencyScale;
+            result.mul_mat_q = @params.MulMatQ;
+            result.n_threads = Threads(@params.Threads);
+            result.n_threads_batch = Threads(@params.BatchThreads);
+        }
+
+        private static uint Threads(uint? value)
+        {
+            if (value is > 0)
+                return (uint)value;
+
+            return (uint)Math.Max(Environment.ProcessorCount / 2, 1);
+        }
+    }
+}
@@ -12,41 +12,30 @@ namespace LLama.Extensions
     public static class IModelParamsExtensions
     {
         /// <summary>
-        /// Convert the given `IModelParams` into a `LLamaContextParams`
+        /// Convert the given `IModelParams` into a `LLamaModelParams`
         /// </summary>
         /// <param name="params"></param>
         /// <param name="result"></param>
         /// <returns></returns>
         /// <exception cref="FileNotFoundException"></exception>
         /// <exception cref="ArgumentException"></exception>
-        public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
+        public static MemoryHandle ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
        {
-            if (!File.Exists(@params.ModelPath))
-                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
-
             if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
                 throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");

-            result = NativeApi.llama_context_default_params();
-            result.n_ctx = @params.ContextSize;
-            result.n_batch = @params.BatchSize;
+            result = NativeApi.llama_model_default_params();
             result.main_gpu = @params.MainGpu;
             result.n_gpu_layers = @params.GpuLayerCount;
-            result.seed = @params.Seed;
-            result.f16_kv = @params.UseFp16Memory;
-            result.use_mmap = @params.UseMemorymap;
             result.use_mlock = @params.UseMemoryLock;
-            result.logits_all = @params.Perplexity;
-            result.embedding = @params.EmbeddingMode;
-            result.low_vram = @params.LowVram;
-            result.rope_freq_base = @params.RopeFrequencyBase;
-            result.rope_freq_scale = @params.RopeFrequencyScale;
-            result.mul_mat_q = @params.MulMatQ;
+            result.use_mmap = @params.UseMemorymap;
+            result.vocab_only = @params.VocabOnly;

             var pin = @params.TensorSplits.AsMemory().Pin();
             unsafe
             {
-                result.tensor_split = (nint)pin.Pointer;
+                result.tensor_split = (float*)pin.Pointer;
             }

             return pin;
@@ -42,14 +42,9 @@ namespace LLama
         public int EmbeddingSize => _ctx.EmbeddingSize;

         /// <summary>
-        /// Get the number of tokens in the KV Cache for this context
+        /// The context params set for this context
         /// </summary>
-        public int KVCacheTokenCount => _ctx.KVCacheTokenCount;
-
-        /// <summary>
-        /// The model params set for this model.
-        /// </summary>
-        public IModelParams Params { get; set; }
+        public IContextParams Params { get; set; }

         /// <summary>
         /// The native handle, which is used to be passed to the native APIs
@@ -62,24 +57,7 @@ namespace LLama
         /// </summary>
         public Encoding Encoding => _encoding;

-        /// <summary>
-        /// 
-        /// </summary>
-        /// <param name="params">Model params.</param>
-        /// <param name="logger">The logger.</param>
-        [Obsolete("Use the LLamaWeights.CreateContext instead")]
-        public LLamaContext(IModelParams @params, ILogger? logger = null)
-        {
-            Params = @params;
-
-            _logger = logger;
-            _encoding = @params.Encoding;
-
-            _logger?.LogInformation($"[LLamaContext] Initializing LLama model with params: {this.Params}");
-            _ctx = Utils.InitLLamaContextFromModelParams(Params);
-        }
-
-        internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILogger? logger = null)
+        internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
         {
             Params = @params;
@@ -95,7 +73,7 @@ namespace LLama
         /// <param name="params"></param>
         /// <param name="logger"></param>
         /// <exception cref="ObjectDisposedException"></exception>
-        public LLamaContext(LLamaWeights model, IModelParams @params, ILogger? logger = null)
+        public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null)
         {
             if (model.NativeHandle.IsClosed)
                 throw new ObjectDisposedException("Cannot create context, model weights have been disposed");
@@ -105,30 +83,20 @@ namespace LLama
             _logger = logger;
             _encoding = @params.Encoding;

-            using var pin = @params.ToLlamaContextParams(out var lparams);
+            @params.ToLlamaContextParams(out var lparams);
             _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
         }

-        /// <summary>
-        /// Create a copy of the current state of this context
-        /// </summary>
-        /// <returns></returns>
-        public LLamaContext Clone()
-        {
-            using var pin = Params.ToLlamaContextParams(out var lparams);
-            var clone = _ctx.Clone(lparams);
-            return new LLamaContext(clone, Params);
-        }
-
         /// <summary>
         /// Tokenize a string.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="addBos">Whether to add a bos to the text.</param>
+        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
         /// <returns></returns>
-        public llama_token[] Tokenize(string text, bool addBos = true)
+        public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
         {
-            return _ctx.Tokenize(text, addBos, _encoding);
+            return _ctx.Tokenize(text, addBos, special, _encoding);
         }

         /// <summary>
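
`Tokenize` (both on the native model handle, as in the test changes above, and on `LLamaContext`) gains a `special` flag that allows special/control tokens to be parsed instead of being treated as plain text. A hedged sketch of the public overload (placeholder path; the `<s>` marker is just an illustration):

```csharp
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf"); // hypothetical path
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// With special: false the markup is tokenized as ordinary characters;
// with special: true control tokens written out as text can be recognised.
var asPlainText = context.Tokenize("<s>Hello", addBos: false, special: false);
var asSpecial   = context.Tokenize("<s>Hello", addBos: false, special: true);
```
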
@@ -177,19 +145,6 @@ namespace LLama
             fileStream.SetLength(writtenBytes);
         }

-        /// <summary>
-        /// Get the state data as a byte array.
-        /// </summary>
-        /// <returns></returns>
-        [Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
-        public byte[] GetStateData()
-        {
-            var stateSize = NativeApi.llama_get_state_size(_ctx);
-            byte[] stateMemory = new byte[stateSize];
-            NativeApi.llama_copy_state_data(_ctx, stateMemory);
-            return stateMemory;
-        }
-
         /// <summary>
         /// Get the state data as an opaque handle
         /// </summary>
@@ -198,31 +153,28 @@ namespace LLama
         {
             var stateSize = _ctx.GetStateSize();

-            unsafe
+            // Allocate a chunk of memory large enough to hold the entire state
+            var memory = Marshal.AllocHGlobal((nint)stateSize);
+            try
             {
-                // Allocate a chunk of memory large enough to hold the entire state
-                var memory = Marshal.AllocHGlobal((nint)stateSize);
-                try
-                {
-                    // Copy the state data into memory, discover the actual size required
-                    var actualSize = _ctx.GetState(memory, stateSize);
+                // Copy the state data into memory, discover the actual size required
+                var actualSize = _ctx.GetState(memory, stateSize);

-                    // Shrink to size
-                    memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
+                // Shrink to size
+                memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);

-                    // Wrap memory in a "state"
-                    var state = new State(memory);
+                // Wrap memory in a "state"
+                var state = new State(memory);

-                    // Set memory to zero, to prevent it being freed in finally block
-                    memory = IntPtr.Zero;
+                // Set memory to zero, to prevent it being freed in finally block
+                memory = IntPtr.Zero;

-                    return state;
-                }
-                finally
-                {
-                    if (memory != IntPtr.Zero)
-                        Marshal.FreeHGlobal(memory);
-                }
+                return state;
+            }
+            finally
+            {
+                if (memory != IntPtr.Zero)
+                    Marshal.FreeHGlobal(memory);
             }
         }
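
With the obsolete `byte[]`-based `GetStateData` removed above and `LoadState(byte[])` removed just below, saving and restoring a context goes through the opaque `State` handle; the reworked `GetState` simply swaps the old `unsafe` block for a `try`/`finally` around the unmanaged buffer. A hedged usage sketch, assuming a `context` created as in the earlier examples and the remaining `LoadState(State)` overload:

```csharp
// Snapshot the current context state...
using var state = context.GetState();

// ...run some inference that mutates the KV cache...

// ...then roll the context back to the snapshot
context.LoadState(state);
```
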
@@ -247,21 +199,6 @@ namespace LLama
             }
         }

-        /// <summary>
-        /// Load the state from memory.
-        /// </summary>
-        /// <param name="stateData"></param>
-        /// <exception cref="RuntimeError"></exception>
-        public void LoadState(byte[] stateData)
-        {
-            int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
-            if (stateData.Length > stateSize)
-            {
-                throw new RuntimeError("Failed to validate state size.");
-            }
-            NativeApi.llama_set_state_data(_ctx, stateData);
-        }
-
         /// <summary>
         /// Load the state from memory.
         /// </summary>
@@ -463,15 +400,15 @@ namespace LLama
         public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
         {
             var total = tokens.Length;
-            for(var i = 0; i < total; i += Params.BatchSize)
+            for(var i = 0; i < total; i += (int)Params.BatchSize)
             {
                 var n_eval = total - i;
                 if (n_eval > Params.BatchSize)
                 {
-                    n_eval = Params.BatchSize;
+                    n_eval = (int)Params.BatchSize;
                 }

-                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
+                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
                 {
                     _logger?.LogError($"[LLamaContext] Failed to eval.");
                     throw new RuntimeError("Failed to eval.");
@@ -18,19 +18,22 @@ namespace LLama
         /// </summary>
         public int EmbeddingSize => _ctx.EmbeddingSize;

-        /// <summary>
-        /// 
-        /// </summary>
-        /// <param name="params"></param>
-        public LLamaEmbedder(IModelParams @params)
+        public LLamaEmbedder(ILLamaParams allParams)
+            : this(allParams, allParams)
         {
-            @params.EmbeddingMode = true;
-            using var weights = LLamaWeights.LoadFromFile(@params);
-            _ctx = weights.CreateContext(@params);
         }

-        public LLamaEmbedder(LLamaWeights weights, IModelParams @params)
+        public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
         {
+            using var weights = LLamaWeights.LoadFromFile(modelParams);
+            contextParams.EmbeddingMode = true;
+            _ctx = weights.CreateContext(contextParams);
+        }
+
+        public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
+        {
             @params.EmbeddingMode = true;
             _ctx = weights.CreateContext(@params);
         }
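
`LLamaEmbedder` follows the same split: it can now wrap already-loaded `LLamaWeights` plus an `IContextParams`, switching embedding mode on internally. A hedged sketch (placeholder path; `GetEmbeddings` is the embedder's usual entry point and may take extra optional arguments):

```csharp
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf"); // hypothetical path
using var weights = LLamaWeights.LoadFromFile(parameters);

// Re-use the same weights for an embedder; EmbeddingMode is set by the constructor
using var embedder = new LLamaEmbedder(weights, parameters);
var embedding = embedder.GetEmbeddings("Hello, world");
```
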
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Extensions;
+using LLama.Native;

 namespace LLama
 {
@@ -20,7 +21,7 @@ namespace LLama
         : ILLamaExecutor
     {
         private readonly LLamaWeights _weights;
-        private readonly IModelParams _params;
+        private readonly IContextParams _params;

         /// <summary>
         /// The context used by the executor when running the inference.
@@ -32,7 +33,7 @@ namespace LLama
         /// </summary>
         /// <param name="weights"></param>
         /// <param name="params"></param>
-        public StatelessExecutor(LLamaWeights weights, IModelParams @params)
+        public StatelessExecutor(LLamaWeights weights, IContextParams @params)
         {
             _weights = weights;
             _params = @params;
@@ -41,20 +42,6 @@ namespace LLama
             Context.Dispose();
         }

-        /// <summary>
-        /// Create a new stateless executor which will use the model used to create the given context
-        /// </summary>
-        /// <param name="context"></param>
-        [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
-        public StatelessExecutor(LLamaContext context)
-        {
-            _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
-            _params = context.Params;
-
-            Context = _weights.CreateContext(_params);
-            Context.Dispose();
-        }
-
         /// <inheritdoc />
         public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
@@ -114,15 +101,16 @@ namespace LLama
                     break;

                 // when run out of context
-                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
-                if (n_past + tokens.Count > Context.ContextSize)
+                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
+                if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep;
+                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var n_discard = n_left / 2;

-                    n_past = Math.Max(1, inferenceParams.TokensKeep);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

-                    tokens.Clear();
-                    tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
+                    n_past -= n_discard;
                 }

                 // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
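
The stateless executor's context-overflow handling changed the most: instead of resetting `n_past` and re-feeding half of the previous tokens, it now removes roughly half of the non-kept tokens from the KV cache and shifts the rest back, following the linked llama.cpp example. A small worked sketch of the arithmetic (the concrete numbers are illustrative only):

```csharp
using System;

// Illustrative values: TokensKeep = 10 and the cache has reached n_past = 59
int tokensKeep = 10;
int n_past = 59;

int n_left = n_past - tokensKeep - 1;    // 48 tokens are eligible for eviction
int n_discard = n_left / 2;              // 24 of them get discarded

// llama_kv_cache_seq_rm removes cells [tokensKeep + 1, tokensKeep + 1 + n_discard)
// llama_kv_cache_seq_shift then slides the remaining cells back by n_discard
n_past -= n_discard;                     // generation continues from position 35

Console.WriteLine($"n_left={n_left}, n_discard={n_discard}, n_past={n_past}");
```
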
| @@ -1,5 +1,4 @@ | |||||
| using System; | using System; | ||||
| using System.Text; | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using LLama.Extensions; | using LLama.Extensions; | ||||
| using LLama.Native; | using LLama.Native; | ||||
| @@ -20,11 +19,6 @@ namespace LLama | |||||
| /// <remarks>Be careful how you use this!</remarks> | /// <remarks>Be careful how you use this!</remarks> | ||||
| public SafeLlamaModelHandle NativeHandle => _weights; | public SafeLlamaModelHandle NativeHandle => _weights; | ||||
| /// <summary> | |||||
| /// Encoding to use to convert text into bytes for the model | |||||
| /// </summary> | |||||
| public Encoding Encoding { get; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Total number of tokens in vocabulary of this model | /// Total number of tokens in vocabulary of this model | ||||
| /// </summary> | /// </summary> | ||||
| @@ -35,15 +29,24 @@ namespace LLama | |||||
| /// </summary> | /// </summary> | ||||
| public int ContextSize => NativeHandle.ContextSize; | public int ContextSize => NativeHandle.ContextSize; | ||||
| /// <summary> | |||||
| /// Get the size of this model in bytes | |||||
| /// </summary> | |||||
| public ulong SizeInBytes => NativeHandle.SizeInBytes; | |||||
| /// <summary> | |||||
| /// Get the number of parameters in this model | |||||
| /// </summary> | |||||
| public ulong ParameterCount => NativeHandle.ParameterCount; | |||||
| /// <summary> | /// <summary> | ||||
| /// Dimension of embedding vectors | /// Dimension of embedding vectors | ||||
| /// </summary> | /// </summary> | ||||
| public int EmbeddingSize => NativeHandle.EmbeddingSize; | public int EmbeddingSize => NativeHandle.EmbeddingSize; | ||||
| internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) | |||||
| internal LLamaWeights(SafeLlamaModelHandle weights) | |||||
| { | { | ||||
| _weights = weights; | _weights = weights; | ||||
| Encoding = encoding; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -53,13 +56,20 @@ namespace LLama | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static LLamaWeights LoadFromFile(IModelParams @params) | public static LLamaWeights LoadFromFile(IModelParams @params) | ||||
| { | { | ||||
| using var pin = @params.ToLlamaContextParams(out var lparams); | |||||
| using var pin = @params.ToLlamaModelParams(out var lparams); | |||||
| var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); | var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); | ||||
| if (!string.IsNullOrEmpty(@params.LoraAdapter)) | |||||
| weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); | |||||
| foreach (var adapter in @params.LoraAdapters) | |||||
| { | |||||
| if (string.IsNullOrEmpty(adapter.Path)) | |||||
| continue; | |||||
| if (adapter.Scale <= 0) | |||||
| continue; | |||||
| weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads); | |||||
| } | |||||
| return new LLamaWeights(weights, @params.Encoding); | |||||
| return new LLamaWeights(weights); | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -73,7 +83,7 @@ namespace LLama | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="params"></param> | /// <param name="params"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public LLamaContext CreateContext(IModelParams @params) | |||||
| public LLamaContext CreateContext(IContextParams @params) | |||||
| { | { | ||||
| return new LLamaContext(this, @params); | return new LLamaContext(this, @params); | ||||
| } | } | ||||
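LoadFromFile now walks a collection of LoRA adapters instead of a single adapter path, applying each one with its own scale. Below is a minimal sketch of what calling code might look like; the `LoraAdapter` record and the `Add` call are illustrative assumptions, while `LoraAdapters`, `Path`, `Scale` and `LoraBase` come from the hunk above. Adapters with an empty path or a non-positive scale are skipped.

    // Sketch only: configure several LoRA adapters on the model parameters.
    var parameters = new ModelParams("<model path>")
    {
        LoraBase = "<optional higher-quality base model>",   // forwarded to llama_model_apply_lora_from_file
    };
    parameters.LoraAdapters.Add(new LoraAdapter("<adapter 1>", 1.0f));   // full strength
    parameters.LoraAdapters.Add(new LoraAdapter("<adapter 2>", 0.5f));   // half strength

    // LoadFromFile applies every valid adapter before returning the weights.
    using var weights = LLamaWeights.LoadFromFile(parameters);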
| @@ -0,0 +1,106 @@ | |||||
| using System; | |||||
| namespace LLama.Native; | |||||
| using llama_token = Int32; | |||||
| public sealed class LLamaBatchSafeHandle | |||||
| : SafeLLamaHandleBase | |||||
| { | |||||
| private readonly int _embd; | |||||
| public LLamaNativeBatch Batch { get; private set; } | |||||
| /// <summary> | |||||
| /// the token ids of the input (used when embd is NULL) | |||||
| /// </summary> | |||||
| public Span<llama_token> Token | |||||
| { | |||||
| get | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| if (_embd != 0) | |||||
| return new Span<int>(null, 0); | |||||
| else | |||||
| return new Span<int>(Batch.token, Batch.n_tokens); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL) | |||||
| /// </summary> | |||||
| public Span<float> Embed | |||||
| { | |||||
| get | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) | |||||
| // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | |||||
| if (_embd != 0) | |||||
| return new Span<float>(Batch.embd, Batch.n_tokens * _embd); | |||||
| else | |||||
| return new Span<float>(null, 0); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// the positions of the respective token in the sequence | |||||
| /// </summary> | |||||
| public Span<LLamaPos> Pos | |||||
| { | |||||
| get | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| return new Span<LLamaPos>(Batch.pos, Batch.n_tokens); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// the sequence to which the respective token belongs | |||||
| /// </summary> | |||||
| public Span<LLamaSeqId> Sequence_ID | |||||
| { | |||||
| get | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| return new Span<LLamaSeqId>(Batch.seq_id, Batch.n_tokens); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// if zero, the logits for the respective token will not be output | |||||
| /// </summary> | |||||
| public Span<byte> Logits | |||||
| { | |||||
| get | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| return new Span<byte>(Batch.logits, Batch.n_tokens); | |||||
| } | |||||
| } | |||||
| } | |||||
| public LLamaBatchSafeHandle(int n_tokens, int embd) | |||||
| : base((nint)1) | |||||
| { | |||||
| _embd = embd; | |||||
| Batch = NativeApi.llama_batch_init(n_tokens, embd); | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| NativeApi.llama_batch_free(Batch); | |||||
| Batch = default; | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
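A rough usage sketch for the new handle in token mode (embd == 0). It assumes a `SafeLLamaContextHandle` named `ctx` is already created, and that `llama_batch_init` reports the requested capacity through `n_tokens`, since the spans above use `Batch.n_tokens` as their length; only the member names are taken from the class above.

    // Tokenize a prompt and stage it in a freshly allocated batch.
    var prompt = ctx.Tokenize("Hello", add_bos: true, special: false, Encoding.UTF8);

    using var batch = new LLamaBatchSafeHandle(prompt.Length, embd: 0);
    for (var i = 0; i < prompt.Length; i++)
    {
        batch.Token[i] = prompt[i];               // token ids (embd == 0, so Token is the active span)
        batch.Pos[i] = i;                         // position of each token in the sequence
        batch.Sequence_ID[i] = new LLamaSeqId(0); // everything belongs to sequence 0
        batch.Logits[i] = 0;                      // no logits for intermediate tokens
    }
    batch.Logits[prompt.Length - 1] = 1;          // request logits only for the last token

    if (ctx.Decode(batch) != 0)
        throw new RuntimeError("llama_decode failed");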
| @@ -19,32 +19,27 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// RNG seed, -1 for random | /// RNG seed, -1 for random | ||||
| /// </summary> | /// </summary> | ||||
| public int seed; | |||||
| public uint seed; | |||||
| /// <summary> | /// <summary> | ||||
| /// text context | /// text context | ||||
| /// </summary> | /// </summary> | ||||
| public int n_ctx; | |||||
| public uint n_ctx; | |||||
| /// <summary> | /// <summary> | ||||
| /// prompt processing batch size | /// prompt processing batch size | ||||
| /// </summary> | /// </summary> | ||||
| public int n_batch; | |||||
| public uint n_batch; | |||||
| /// <summary> | /// <summary> | ||||
| /// number of layers to store in VRAM | |||||
| /// number of threads to use for generation | |||||
| /// </summary> | /// </summary> | ||||
| public int n_gpu_layers; | |||||
| public uint n_threads; | |||||
| /// <summary> | /// <summary> | ||||
| /// the GPU that is used for scratch and small tensors | |||||
| /// number of threads to use for batch processing | |||||
| /// </summary> | /// </summary> | ||||
| public int main_gpu; | |||||
| /// <summary> | |||||
| /// how to split layers across multiple GPUs | |||||
| /// </summary> | |||||
| public nint tensor_split; | |||||
| public uint n_threads_batch; | |||||
| /// <summary> | /// <summary> | ||||
| /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 | /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 | ||||
| @@ -58,26 +53,6 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| public float rope_freq_scale; | public float rope_freq_scale; | ||||
| /// <summary> | |||||
| /// called with a progress value between 0 and 1, pass NULL to disable | |||||
| /// </summary> | |||||
| public IntPtr progress_callback; | |||||
| /// <summary> | |||||
| /// context pointer passed to the progress callback | |||||
| /// </summary> | |||||
| public IntPtr progress_callback_user_data; | |||||
| /// <summary> | |||||
| /// if true, reduce VRAM usage at the cost of performance | |||||
| /// </summary> | |||||
| public bool low_vram | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_low_vram); | |||||
| set => _low_vram = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _low_vram; | |||||
| /// <summary> | /// <summary> | ||||
| /// if true, use experimental mul_mat_q kernels | /// if true, use experimental mul_mat_q kernels | ||||
| /// </summary> | /// </summary> | ||||
| @@ -108,36 +83,6 @@ namespace LLama.Native | |||||
| } | } | ||||
| private sbyte _logits_all; | private sbyte _logits_all; | ||||
| /// <summary> | |||||
| /// only load the vocabulary, no weights | |||||
| /// </summary> | |||||
| public bool vocab_only | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_vocab_only); | |||||
| set => _vocab_only = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _vocab_only; | |||||
| /// <summary> | |||||
| /// use mmap if possible | |||||
| /// </summary> | |||||
| public bool use_mmap | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_use_mmap); | |||||
| set => _use_mmap = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _use_mmap; | |||||
| /// <summary> | |||||
| /// force system to keep model in RAM | |||||
| /// </summary> | |||||
| public bool use_mlock | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_use_mlock); | |||||
| set => _use_mlock = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _use_mlock; | |||||
| /// <summary> | /// <summary> | ||||
| /// embedding mode only | /// embedding mode only | ||||
| /// </summary> | /// </summary> | ||||
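With the split into model and context parameters, the remaining native context fields are unsigned and thread counts are now split between generation and batch processing. A hedged sketch of filling the struct from its defaults follows; it assumes the `llama_context_default_params` wrapper referenced later in this diff ("Create a LLamaContextParams with default values").

    // Start from llama.cpp's defaults and override only what is needed.
    var cparams = NativeApi.llama_context_default_params();
    cparams.seed            = 1337u;   // now uint; there is no -1 sentinel any more
    cparams.n_ctx           = 2048u;   // text context
    cparams.n_batch         = 512u;    // prompt processing batch size
    cparams.n_threads       = 8u;      // single-token generation threads
    cparams.n_threads_batch = 8u;      // prompt/batch processing threads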
| @@ -0,0 +1,67 @@ | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native | |||||
| { | |||||
| /// <summary> | |||||
| /// A C# representation of the llama.cpp `llama_model_params` struct | |||||
| /// </summary> | |||||
| [StructLayout(LayoutKind.Sequential)] | |||||
| public unsafe struct LLamaModelParams | |||||
| { | |||||
| /// <summary> | |||||
| /// number of layers to store in VRAM | |||||
| /// </summary> | |||||
| public int n_gpu_layers; | |||||
| /// <summary> | |||||
| /// the GPU that is used for scratch and small tensors | |||||
| /// </summary> | |||||
| public int main_gpu; | |||||
| /// <summary> | |||||
| /// how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) | |||||
| /// </summary> | |||||
| public float* tensor_split; | |||||
| /// <summary> | |||||
| /// called with a progress value between 0 and 1, pass NULL to disable | |||||
| /// </summary> | |||||
| LlamaProgressCallback progress_callback; | |||||
| /// <summary> | |||||
| /// context pointer passed to the progress callback | |||||
| /// </summary> | |||||
| void* progress_callback_user_data; | |||||
| /// <summary> | |||||
| /// only load the vocabulary, no weights | |||||
| /// </summary> | |||||
| public bool vocab_only | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_vocab_only); | |||||
| set => _vocab_only = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _vocab_only; | |||||
| /// <summary> | |||||
| /// use mmap if possible | |||||
| /// </summary> | |||||
| public bool use_mmap | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_use_mmap); | |||||
| set => _use_mmap = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _use_mmap; | |||||
| /// <summary> | |||||
| /// force system to keep model in RAM | |||||
| /// </summary> | |||||
| public bool use_mlock | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_use_mlock); | |||||
| set => _use_mlock = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _use_mlock; | |||||
| } | |||||
| } | |||||
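Model-loading options (GPU offload, mmap/mlock, vocab_only) now live in this struct rather than in the context parameters. A small sketch, assuming a local GGUF file; `llama_model_default_params` and the `LoadFromFile(string, LLamaModelParams)` overload both appear later in this diff.

    // Populate llama_model_params from the native defaults.
    var mparams = NativeApi.llama_model_default_params();
    mparams.n_gpu_layers = 20;     // offload 20 layers to the GPU
    mparams.use_mmap     = true;   // map the file rather than reading it all into RAM
    mparams.use_mlock    = false;

    using var model = SafeLlamaModelHandle.LoadFromFile("<model path>", mparams);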
| @@ -36,5 +36,15 @@ namespace LLama.Native | |||||
| set => _quantize_output_tensor = Convert.ToSByte(value); | set => _quantize_output_tensor = Convert.ToSByte(value); | ||||
| } | } | ||||
| private sbyte _quantize_output_tensor; | private sbyte _quantize_output_tensor; | ||||
| /// <summary> | |||||
| /// only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored | |||||
| /// </summary> | |||||
| public bool only_copy | |||||
| { | |||||
| get => Convert.ToBoolean(_only_copy); | |||||
| set => _only_copy = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _only_copy; | |||||
| } | } | ||||
| } | } | ||||
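A tiny sketch of the new flag. The containing struct is LLamaSharp's native quantize-parameters type (its name is not shown in this hunk, so `LLamaModelQuantizeParams` below is an assumption); with `only_copy` set, tensors are copied verbatim and the ftype/requantize/output-tensor options are ignored.

    // Assumed struct name; only the only_copy property is taken from this diff.
    var qparams = new LLamaModelQuantizeParams
    {
        only_copy = true,   // copy tensors unchanged instead of re-quantizing them
    };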
| @@ -0,0 +1,45 @@ | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native; | |||||
| using llama_token = Int32; | |||||
| /// <summary> | |||||
| /// Input data for llama_decode | |||||
| /// A llama_batch object can contain input about one or many sequences | |||||
| /// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens | |||||
| /// </summary> | |||||
| [StructLayout(LayoutKind.Sequential)] | |||||
| public readonly unsafe struct LLamaNativeBatch | |||||
| { | |||||
| /// <summary> | |||||
| /// The number of items pointed at by pos, seq_id and logits. | |||||
| /// </summary> | |||||
| public readonly int n_tokens; | |||||
| /// <summary> | |||||
| /// Either an array of `n_tokens` `llama_token` values, or `NULL`, depending on how this batch was created | |||||
| /// </summary> | |||||
| public readonly llama_token* token; | |||||
| /// <summary> | |||||
| /// Either an array of `n_tokens * embd` floats, or `NULL`, depending on how this batch was created | |||||
| /// </summary> | |||||
| public readonly float* embd; | |||||
| /// <summary> | |||||
| /// the positions of the respective token in the sequence | |||||
| /// </summary> | |||||
| public readonly LLamaPos* pos; | |||||
| /// <summary> | |||||
| /// the sequence to which the respective token belongs | |||||
| /// </summary> | |||||
| public readonly LLamaSeqId* seq_id; | |||||
| /// <summary> | |||||
| /// if zero, the logits for the respective token will not be output | |||||
| /// </summary> | |||||
| public readonly byte* logits; | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| namespace LLama.Native; | |||||
| public record struct LLamaPos | |||||
| { | |||||
| public int Value; | |||||
| public LLamaPos(int value) | |||||
| { | |||||
| Value = value; | |||||
| } | |||||
| public static explicit operator int(LLamaPos pos) => pos.Value; | |||||
| public static implicit operator LLamaPos(int value) => new(value); | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| namespace LLama.Native; | |||||
| public record struct LLamaSeqId | |||||
| { | |||||
| public int Value; | |||||
| public LLamaSeqId(int value) | |||||
| { | |||||
| Value = value; | |||||
| } | |||||
| public static explicit operator int(LLamaSeqId pos) => pos.Value; | |||||
| public static explicit operator LLamaSeqId(int value) => new(value); | |||||
| } | |||||
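The two wrappers convert differently, which matters when filling batches: positions accept a plain int implicitly, while sequence ids need an explicit cast or the constructor.

    LLamaPos pos = 42;                 // implicit conversion from int
    int raw = (int)pos;                // back to int is explicit

    var seqA = (LLamaSeqId)0;          // explicit conversion from int
    var seqB = new LLamaSeqId(0);      // equivalent constructor form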
| @@ -2,7 +2,6 @@ | |||||
| using System.Buffers; | using System.Buffers; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using LLama.Common; | |||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| #pragma warning disable IDE1006 // Naming Styles | #pragma warning disable IDE1006 // Naming Styles | ||||
| @@ -110,6 +109,13 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern bool llama_empty_call(); | public static extern bool llama_empty_call(); | ||||
| /// <summary> | |||||
| /// Create a LLamaModelParams with default values | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern LLamaModelParams llama_model_default_params(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a LLamaContextParams with default values | /// Create a LLamaContextParams with default values | ||||
| /// </summary> | /// </summary> | ||||
| @@ -138,18 +144,6 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern bool llama_mlock_supported(); | public static extern bool llama_mlock_supported(); | ||||
| /// <summary> | |||||
| /// Export a static computation graph for context of 511 and batch size of 1 | |||||
| /// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these | |||||
| /// parameters here to keep things simple | |||||
| /// IMPORTANT: do not use for anything else other than debugging and testing! | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="fname"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname); | |||||
| /// <summary> | /// <summary> | ||||
| /// Various functions for loading a ggml llama model. | /// Various functions for loading a ggml llama model. | ||||
| /// Allocate (almost) all memory needed for the model. | /// Allocate (almost) all memory needed for the model. | ||||
| @@ -159,7 +153,7 @@ namespace LLama.Native | |||||
| /// <param name="params"></param> | /// <param name="params"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params); | |||||
| public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params); | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a new llama_context with the given model. | /// Create a new llama_context with the given model. | ||||
| @@ -192,7 +186,7 @@ namespace LLama.Native | |||||
| /// <param name="model"></param> | /// <param name="model"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_free_model(IntPtr model); | public static extern void llama_free_model(IntPtr model); | ||||
| /// <summary> | /// <summary> | ||||
| /// Apply a LoRA adapter to a loaded model | /// Apply a LoRA adapter to a loaded model | ||||
| /// path_base_model is the path to a higher quality model to use as a base for | /// path_base_model is the path to a higher quality model to use as a base for | ||||
| @@ -202,19 +196,12 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="model_ptr"></param> | /// <param name="model_ptr"></param> | ||||
| /// <param name="path_lora"></param> | /// <param name="path_lora"></param> | ||||
| /// <param name="scale"></param> | |||||
| /// <param name="path_base_model"></param> | /// <param name="path_base_model"></param> | ||||
| /// <param name="n_threads"></param> | /// <param name="n_threads"></param> | ||||
| /// <returns>Returns 0 on success</returns> | /// <returns>Returns 0 on success</returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads); | |||||
| /// <summary> | |||||
| /// Returns the number of tokens in the KV cache | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx); | |||||
| public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets the current rng seed. | /// Sets the current rng seed. | ||||
| @@ -222,7 +209,7 @@ namespace LLama.Native | |||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <param name="seed"></param> | /// <param name="seed"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, int seed); | |||||
| public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the maximum size in bytes of the state (rng, logits, embedding | /// Returns the maximum size in bytes of the state (rng, logits, embedding | ||||
| @@ -243,21 +230,6 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest); | public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest); | ||||
| /// <summary> | |||||
| /// Copies the state to the specified destination address. | |||||
| /// Destination needs to have allocated enough memory (see llama_get_state_size) | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="dest"></param> | |||||
| /// <returns>the number of bytes copied</returns> | |||||
| public static ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte[] dest) | |||||
| { | |||||
| fixed (byte* dstPtr = &dest[0]) | |||||
| { | |||||
| return llama_copy_state_data(ctx, dstPtr); | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Set the state reading from the specified address | /// Set the state reading from the specified address | ||||
| /// </summary> | /// </summary> | ||||
| @@ -267,20 +239,6 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src); | public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src); | ||||
| /// <summary> | |||||
| /// Set the state reading from the specified address | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="src"></param> | |||||
| /// <returns>the number of bytes read</returns> | |||||
| public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src) | |||||
| { | |||||
| fixed (byte* srcPtr = &src[0]) | |||||
| { | |||||
| return llama_set_state_data(ctx, srcPtr); | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Load session file | /// Load session file | ||||
| /// </summary> | /// </summary> | ||||
| @@ -313,24 +271,9 @@ namespace LLama.Native | |||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| /// <param name="n_tokens"></param> | /// <param name="n_tokens"></param> | ||||
| /// <param name="n_past"></param> | /// <param name="n_past"></param> | ||||
| /// <param name="n_threads"></param> | |||||
| /// <returns>Returns 0 on success</returns> | /// <returns>Returns 0 on success</returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads); | |||||
| /// <summary> | |||||
| /// Run the llama inference to obtain the logits and probabilities for the next token. | |||||
| /// tokens + n_tokens is the provided batch of new tokens to process | |||||
| /// n_past is the number of tokens to use from previous eval calls | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="n_tokens"></param> | |||||
| /// <param name="n_past"></param> | |||||
| /// <param name="n_threads"></param> | |||||
| /// <returns>Returns 0 on success</returns> | |||||
| [DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads); | |||||
| public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past); | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert the provided text into tokens. | /// Convert the provided text into tokens. | ||||
| @@ -341,10 +284,11 @@ namespace LLama.Native | |||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| /// <param name="n_max_tokens"></param> | /// <param name="n_max_tokens"></param> | ||||
| /// <param name="add_bos"></param> | /// <param name="add_bos"></param> | ||||
| /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param> | |||||
| /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | ||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | /// Returns a negative number on failure - the number of tokens that would have been returned | ||||
| /// </returns> | /// </returns> | ||||
| public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) | |||||
| public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special) | |||||
| { | { | ||||
| // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) | // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) | ||||
| var byteCount = encoding.GetByteCount(text); | var byteCount = encoding.GetByteCount(text); | ||||
| @@ -364,7 +308,7 @@ namespace LLama.Native | |||||
| // Do the actual tokenization | // Do the actual tokenization | ||||
| fixed (byte* arrayPtr = array) | fixed (byte* arrayPtr = array) | ||||
| fixed (llama_token* tokensPtr = tokens) | fixed (llama_token* tokensPtr = tokens) | ||||
| return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos); | |||||
| return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special); | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| @@ -372,28 +316,6 @@ namespace LLama.Native | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Convert the provided text into tokens. | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="text"></param> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="n_max_tokens"></param> | |||||
| /// <param name="add_bos"></param> | |||||
| /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | |||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | |||||
| /// </returns> | |||||
| [DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos); | |||||
| /// <summary> | |||||
| /// Get the number of tokens in the model vocabulary for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the size of the context window for the model for this context | /// Get the size of the context window for the model for this context | ||||
| /// </summary> | /// </summary> | ||||
| @@ -402,14 +324,6 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); | public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); | ||||
| /// <summary> | |||||
| /// Get the dimension of embedding vectors from the model for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_n_embd(SafeLLamaContextHandle ctx); | |||||
| /// <summary> | /// <summary> | ||||
| /// Token logits obtained from the last call to llama_eval() | /// Token logits obtained from the last call to llama_eval() | ||||
| /// The logits for the last token are stored in the last row | /// The logits for the last token are stored in the last row | ||||
| @@ -423,22 +337,21 @@ namespace LLama.Native | |||||
| public static extern float* llama_get_logits(SafeLLamaContextHandle ctx); | public static extern float* llama_get_logits(SafeLLamaContextHandle ctx); | ||||
| /// <summary> | /// <summary> | ||||
| /// Get the embeddings for the input | |||||
| /// shape: [n_embd] (1-dimensional) | |||||
| /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx); | |||||
| public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); | |||||
| /// <summary> | /// <summary> | ||||
| /// Token Id -> String. Uses the vocabulary in the provided context | |||||
| /// Get the embeddings for the input | |||||
| /// shape: [n_embd] (1-dimensional) | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <param name="token"></param> | |||||
| /// <returns>Pointer to a string.</returns> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token); | |||||
| public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx); | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the "Beginning of sentence" token | /// Get the "Beginning of sentence" token | ||||
| @@ -488,7 +401,7 @@ namespace LLama.Native | |||||
| /// <param name="model"></param> | /// <param name="model"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_model_n_vocab(SafeLlamaModelHandle model); | |||||
| public static extern int llama_n_vocab(SafeLlamaModelHandle model); | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the size of the context window for the model | /// Get the size of the context window for the model | ||||
| @@ -496,7 +409,7 @@ namespace LLama.Native | |||||
| /// <param name="model"></param> | /// <param name="model"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_model_n_ctx(SafeLlamaModelHandle model); | |||||
| public static extern int llama_n_ctx_train(SafeLlamaModelHandle model); | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the dimension of embedding vectors from this model | /// Get the dimension of embedding vectors from this model | ||||
| @@ -504,7 +417,23 @@ namespace LLama.Native | |||||
| /// <param name="model"></param> | /// <param name="model"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_model_n_embd(SafeLlamaModelHandle model); | |||||
| public static extern int llama_n_embd(SafeLlamaModelHandle model); | |||||
| /// <summary> | |||||
| /// Get the size of the model in bytes | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern ulong llama_model_size(SafeLlamaModelHandle model); | |||||
| /// <summary> | |||||
| /// Get the number of parameters in this model | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern ulong llama_model_n_params(SafeLlamaModelHandle model); | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert a single token into text | /// Convert a single token into text | ||||
| @@ -515,21 +444,23 @@ namespace LLama.Native | |||||
| /// <param name="length">size of the buffer</param> | /// <param name="length">size of the buffer</param> | ||||
| /// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns> | /// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); | |||||
| public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert text into tokens | /// Convert text into tokens | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="model"></param> | /// <param name="model"></param> | ||||
| /// <param name="text"></param> | /// <param name="text"></param> | ||||
| /// <param name="text_len"></param> | |||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| /// <param name="n_max_tokens"></param> | /// <param name="n_max_tokens"></param> | ||||
| /// <param name="add_bos"></param> | /// <param name="add_bos"></param> | ||||
| /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param> | |||||
| /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | ||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | /// Returns a negative number on failure - the number of tokens that would have been returned | ||||
| /// </returns> | /// </returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos); | |||||
| public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special); | |||||
| /// <summary> | /// <summary> | ||||
| /// Register a callback to receive llama log messages | /// Register a callback to receive llama log messages | ||||
| @@ -537,5 +468,98 @@ namespace LLama.Native | |||||
| /// <param name="logCallback"></param> | /// <param name="logCallback"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_log_set(LLamaLogCallback logCallback); | public static extern void llama_log_set(LLamaLogCallback logCallback); | ||||
| } | |||||
| /// <summary> | |||||
| /// Removes all token data for cells in [c0, c1) | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="c0"></param> | |||||
| /// <param name="c1"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_kv_cache_tokens_rm(SafeLLamaContextHandle ctx, int c0, int c1); | |||||
| /// <summary> | |||||
| /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="seq"></param> | |||||
| /// <param name="p0"></param> | |||||
| /// <param name="p1"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1); | |||||
| /// <summary> | |||||
| /// Copy all tokens that belong to the specified sequence to another sequence | |||||
| /// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="src"></param> | |||||
| /// <param name="dest"></param> | |||||
| /// <param name="p0"></param> | |||||
| /// <param name="p1"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_kv_cache_seq_cp(SafeLLamaContextHandle ctx, LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1); | |||||
| /// <summary> | |||||
| /// Removes all tokens that do not belong to the specified sequence | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="seq"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_kv_cache_seq_keep(SafeLLamaContextHandle ctx, LLamaSeqId seq); | |||||
| /// <summary> | |||||
| /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) | |||||
| /// If the KV cache is RoPEd, the KV data is updated accordingly | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="seq"></param> | |||||
| /// <param name="p0"></param> | |||||
| /// <param name="p1"></param> | |||||
| /// <param name="delta"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta); | |||||
| /// <summary> | |||||
| /// Allocates a batch of tokens on the heap | |||||
| /// The batch has to be freed with llama_batch_free() | |||||
| /// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) | |||||
| /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | |||||
| /// The rest of the llama_batch members are allocated with size n_tokens | |||||
| /// All members are left uninitialized | |||||
| /// </summary> | |||||
| /// <param name="n_tokens"></param> | |||||
| /// <param name="embd"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd); | |||||
| /// <summary> | |||||
| /// Frees a batch of tokens allocated with llama_batch_init() | |||||
| /// </summary> | |||||
| /// <param name="batch"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_batch_free(LLamaNativeBatch batch); | |||||
| /// <summary> | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="batch"></param> | |||||
| /// <returns>A positive return value does not mean a fatal error, but rather a warning:<br /> | |||||
| /// - 0: success<br /> | |||||
| /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)<br /> | |||||
| /// - &lt; 0: error<br /> | |||||
| /// </returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch); | |||||
| /// <summary> | |||||
| /// Set the number of threads used for decoding | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="n_threads">n_threads is the number of threads used for generation (single token)</param> | |||||
| /// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch); | |||||
| } | |||||
| } | } | ||||
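These sequence-level cache functions replace re-evaluation tricks like the token-list shuffling in the executor hunk at the top of this change. A hedged sketch of a native context shift follows, assuming a context handle `ctx` and counters `n_keep`, `n_discard` and `n_past` like those used in the executor code, with the whole conversation living in sequence 0.

    var seq = new LLamaSeqId(0);

    // Drop the n_discard tokens that follow the n_keep tokens we want to preserve...
    NativeApi.llama_kv_cache_seq_rm(ctx, seq, n_keep + 1, n_keep + n_discard + 1);

    // ...then slide the remaining cells back so their positions stay contiguous.
    NativeApi.llama_kv_cache_seq_shift(ctx, seq, n_keep + n_discard + 1, n_past, -n_discard);

    n_past -= n_discard;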
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | using System.Buffers; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| @@ -21,26 +22,13 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// Total number of tokens in the context | /// Total number of tokens in the context | ||||
| /// </summary> | /// </summary> | ||||
| public int ContextSize => ThrowIfDisposed().ContextSize; | |||||
| public int ContextSize => NativeApi.llama_n_ctx(this); | |||||
| /// <summary> | /// <summary> | ||||
| /// Dimension of embedding vectors | /// Dimension of embedding vectors | ||||
| /// </summary> | /// </summary> | ||||
| public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize; | public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize; | ||||
| /// <summary> | |||||
| /// Get the number of tokens in the KV Cache for this context | |||||
| /// </summary> | |||||
| public int KVCacheTokenCount | |||||
| { | |||||
| get | |||||
| { | |||||
| if (IsClosed) | |||||
| throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed"); | |||||
| return NativeApi.llama_get_kv_cache_token_count(this); | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the model which this context is using | /// Get the model which this context is using | ||||
| /// </summary> | /// </summary> | ||||
| @@ -64,17 +52,20 @@ namespace LLama.Native | |||||
| _model.DangerousAddRef(ref success); | _model.DangerousAddRef(ref success); | ||||
| if (!success) | if (!success) | ||||
| throw new RuntimeError("Failed to increment model refcount"); | throw new RuntimeError("Failed to increment model refcount"); | ||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override bool ReleaseHandle() | protected override bool ReleaseHandle() | ||||
| { | { | ||||
| NativeApi.llama_free(DangerousGetHandle()); | |||||
| SetHandle(IntPtr.Zero); | |||||
| // Decrement refcount on model | // Decrement refcount on model | ||||
| _model?.DangerousRelease(); | _model?.DangerousRelease(); | ||||
| _model = null!; | _model = null!; | ||||
| NativeApi.llama_free(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -103,46 +94,38 @@ namespace LLama.Native | |||||
| return new(ctx_ptr, model); | return new(ctx_ptr, model); | ||||
| } | } | ||||
| #endregion | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a new llama context with a clone of the current llama context state | |||||
| /// Token logits obtained from the last call to llama_eval() | |||||
| /// The logits for the last token are stored in the last row | |||||
| /// Can be mutated in order to change the probabilities of the next token.<br /> | |||||
| /// Rows: n_tokens<br /> | |||||
| /// Cols: n_vocab | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="lparams"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public SafeLLamaContextHandle Clone(LLamaContextParams lparams) | |||||
| public Span<float> GetLogits() | |||||
| { | { | ||||
| // Allocate space to read the state of the current context | |||||
| var stateSize = GetStateSize(); | |||||
| var stateMemory = Marshal.AllocHGlobal((nint)stateSize); | |||||
| try | |||||
| { | |||||
| // Copy state from this context into memory | |||||
| GetState(stateMemory, stateSize); | |||||
| // Create a new context | |||||
| var newCtx = Create(ModelHandle, lparams); | |||||
| // Copy state into new context | |||||
| newCtx.SetState(stateMemory); | |||||
| var model = ThrowIfDisposed(); | |||||
| return newCtx; | |||||
| } | |||||
| finally | |||||
| unsafe | |||||
| { | { | ||||
| Marshal.FreeHGlobal(stateMemory); | |||||
| var logits = NativeApi.llama_get_logits(this); | |||||
| return new Span<float>(logits, model.VocabCount); | |||||
| } | } | ||||
| } | } | ||||
| #endregion | |||||
| #region tokens | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert the given text into tokens | /// Convert the given text into tokens | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="text">The text to tokenize</param> | /// <param name="text">The text to tokenize</param> | ||||
| /// <param name="add_bos">Whether the "BOS" token should be added</param> | /// <param name="add_bos">Whether the "BOS" token should be added</param> | ||||
| /// <param name="encoding">Encoding to use for the text</param> | /// <param name="encoding">Encoding to use for the text</param> | ||||
| /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public int[] Tokenize(string text, bool add_bos, Encoding encoding) | |||||
| public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) | |||||
| { | { | ||||
| ThrowIfDisposed(); | ThrowIfDisposed(); | ||||
| @@ -158,7 +141,7 @@ namespace LLama.Native | |||||
| try | try | ||||
| { | { | ||||
| // Do the actual conversion | // Do the actual conversion | ||||
| var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos); | |||||
| var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special); | |||||
| if (n < 0) | if (n < 0) | ||||
| { | { | ||||
| throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + | throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + | ||||
| @@ -177,25 +160,6 @@ namespace LLama.Native | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Token logits obtained from the last call to llama_eval() | |||||
| /// The logits for the last token are stored in the last row | |||||
| /// Can be mutated in order to change the probabilities of the next token.<br /> | |||||
| /// Rows: n_tokens<br /> | |||||
| /// Cols: n_vocab | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public Span<float> GetLogits() | |||||
| { | |||||
| var model = ThrowIfDisposed(); | |||||
| unsafe | |||||
| { | |||||
| var logits = NativeApi.llama_get_logits(this); | |||||
| return new Span<float>(logits, model.VocabCount); | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert a token into a string | /// Convert a token into a string | ||||
| /// </summary> | /// </summary> | ||||
| @@ -228,25 +192,31 @@ namespace LLama.Native | |||||
| { | { | ||||
| return ThrowIfDisposed().TokenToSpan(token, dest); | return ThrowIfDisposed().TokenToSpan(token, dest); | ||||
| } | } | ||||
| #endregion | |||||
| /// <summary> | /// <summary> | ||||
| /// Run the llama inference to obtain the logits and probabilities for the next token. | /// Run the llama inference to obtain the logits and probabilities for the next token. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="tokens">The provided batch of new tokens to process</param> | /// <param name="tokens">The provided batch of new tokens to process</param> | ||||
| /// <param name="n_past">the number of tokens to use from previous eval calls</param> | /// <param name="n_past">the number of tokens to use from previous eval calls</param> | ||||
| /// <param name="n_threads"></param> | |||||
| /// <returns>Returns true on success</returns> | /// <returns>Returns true on success</returns> | ||||
| public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads) | |||||
| public bool Eval(ReadOnlySpan<int> tokens, int n_past) | |||||
| { | { | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| fixed (int* pinned = tokens) | fixed (int* pinned = tokens) | ||||
| { | { | ||||
| return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0; | |||||
| var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past); | |||||
| return ret == 0; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| public int Decode(LLamaBatchSafeHandle batch) | |||||
| { | |||||
| return NativeApi.llama_decode(this, batch.Batch); | |||||
| } | |||||
| #region state | #region state | ||||
| /// <summary> | /// <summary> | ||||
| /// Get the size of the state, when saved as bytes | /// Get the size of the state, when saved as bytes | ||||
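Since `n_threads` is gone from `Eval`, callers now configure thread counts once on the context through `llama_set_n_threads` and then evaluate with just the tokens and `n_past`. A small migration sketch; `ctx`, `tokens` and `n_past` are assumed to already exist.

    // Configure thread counts once, instead of passing n_threads on every Eval call.
    NativeApi.llama_set_n_threads(ctx, n_threads: 8, n_threads_batch: 8);

    if (!ctx.Eval(tokens, n_past))
        throw new RuntimeError("llama_eval failed");
    n_past += tokens.Length;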
| @@ -29,18 +29,30 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| public int EmbeddingSize { get; } | public int EmbeddingSize { get; } | ||||
| /// <summary> | |||||
| /// Get the size of this model in bytes | |||||
| /// </summary> | |||||
| public ulong SizeInBytes { get; } | |||||
| /// <summary> | |||||
| /// Get the number of parameters in this model | |||||
| /// </summary> | |||||
| public ulong ParameterCount { get; } | |||||
| internal SafeLlamaModelHandle(IntPtr handle) | internal SafeLlamaModelHandle(IntPtr handle) | ||||
| : base(handle) | : base(handle) | ||||
| { | { | ||||
| VocabCount = NativeApi.llama_model_n_vocab(this); | |||||
| ContextSize = NativeApi.llama_model_n_ctx(this); | |||||
| EmbeddingSize = NativeApi.llama_model_n_embd(this); | |||||
| VocabCount = NativeApi.llama_n_vocab(this); | |||||
| ContextSize = NativeApi.llama_n_ctx_train(this); | |||||
| EmbeddingSize = NativeApi.llama_n_embd(this); | |||||
| SizeInBytes = NativeApi.llama_model_size(this); | |||||
| ParameterCount = NativeApi.llama_model_n_params(this); | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override bool ReleaseHandle() | protected override bool ReleaseHandle() | ||||
| { | { | ||||
| NativeApi.llama_free_model(handle); | |||||
| NativeApi.llama_free_model(DangerousGetHandle()); | |||||
| SetHandle(IntPtr.Zero); | SetHandle(IntPtr.Zero); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -52,7 +64,7 @@ namespace LLama.Native | |||||
| /// <param name="lparams"></param> | /// <param name="lparams"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams) | |||||
| public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams) | |||||
| { | { | ||||
| var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams); | var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams); | ||||
| if (model_ptr == IntPtr.Zero) | if (model_ptr == IntPtr.Zero) | ||||
| @@ -62,21 +74,24 @@ namespace LLama.Native | |||||
| } | } | ||||
| #region LoRA | #region LoRA | ||||
| /// <summary> | /// <summary> | ||||
| /// Apply a LoRA adapter to a loaded model | /// Apply a LoRA adapter to a loaded model | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="lora"></param> | /// <param name="lora"></param> | ||||
| /// <param name="scale"></param> | |||||
| /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the | /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the | ||||
| /// adapter. Can be NULL to use the current loaded model.</param> | /// adapter. Can be NULL to use the current loaded model.</param> | ||||
| /// <param name="threads"></param> | /// <param name="threads"></param> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1) | |||||
| public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null) | |||||
| { | { | ||||
| var err = NativeApi.llama_model_apply_lora_from_file( | var err = NativeApi.llama_model_apply_lora_from_file( | ||||
| this, | this, | ||||
| lora, | lora, | ||||
| scale, | |||||
| string.IsNullOrEmpty(modelBase) ? null : modelBase, | string.IsNullOrEmpty(modelBase) ? null : modelBase, | ||||
| threads | |||||
| (int?)threads ?? -1 | |||||
| ); | ); | ||||
| if (err != 0) | if (err != 0) | ||||
| @@ -97,7 +112,7 @@ namespace LLama.Native | |||||
| { | { | ||||
| fixed (byte* destPtr = dest) | fixed (byte* destPtr = dest) | ||||
| { | { | ||||
| var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length); | |||||
| var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length); | |||||
| return Math.Abs(length); | return Math.Abs(length); | ||||
| } | } | ||||
| } | } | ||||
| @@ -113,7 +128,7 @@ namespace LLama.Native | |||||
| { | { | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); | |||||
| var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); | |||||
| if (length == 0) | if (length == 0) | ||||
| return ""; | return ""; | ||||
| @@ -121,7 +136,7 @@ namespace LLama.Native | |||||
| fixed (byte* bytePtr = bytes) | fixed (byte* bytePtr = bytes) | ||||
| { | { | ||||
| var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); | |||||
| var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); | |||||
| Debug.Assert(written == bytes.Length); | Debug.Assert(written == bytes.Length); | ||||
| return encoding.GetString(bytePtr, bytes.Length); | return encoding.GetString(bytePtr, bytes.Length); | ||||
| @@ -139,7 +154,7 @@ namespace LLama.Native | |||||
| { | { | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); | |||||
| var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); | |||||
| if (length == 0) | if (length == 0) | ||||
| return; | return; | ||||
| @@ -147,7 +162,7 @@ namespace LLama.Native | |||||
| fixed (byte* bytePtr = bytes) | fixed (byte* bytePtr = bytes) | ||||
| { | { | ||||
| // Decode into bytes | // Decode into bytes | ||||
| var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); | |||||
| var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); | |||||
| Debug.Assert(written == bytes.Length); | Debug.Assert(written == bytes.Length); | ||||
| // Decode into chars | // Decode into chars | ||||
| @@ -256,8 +271,9 @@ namespace LLama.Native | |||||
| /// <param name="text"></param> | /// <param name="text"></param> | ||||
| /// <param name="add_bos"></param> | /// <param name="add_bos"></param> | ||||
| /// <param name="encoding"></param> | /// <param name="encoding"></param> | ||||
| /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public int[] Tokenize(string text, bool add_bos, Encoding encoding) | |||||
| public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) | |||||
| { | { | ||||
| // Convert string to bytes, adding one extra byte to the end (null terminator) | // Convert string to bytes, adding one extra byte to the end (null terminator) | ||||
| var bytesCount = encoding.GetByteCount(text); | var bytesCount = encoding.GetByteCount(text); | ||||
| @@ -276,13 +292,13 @@ namespace LLama.Native | |||||
| fixed (byte* bytesPtr = &bytes[0]) | fixed (byte* bytesPtr = &bytes[0]) | ||||
| { | { | ||||
| // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) | // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) | ||||
| var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos); | |||||
| var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special); | |||||
| // Tokenize again, this time outputting into an array of exactly the right size | // Tokenize again, this time outputting into an array of exactly the right size | ||||
| var tokens = new int[count]; | var tokens = new int[count]; | ||||
| fixed (int* tokensPtr = &tokens[0]) | fixed (int* tokensPtr = &tokens[0]) | ||||
| { | { | ||||
| NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos); | |||||
| NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); | |||||
| return tokens; | return tokens; | ||||
| } | } | ||||
| } | } | ||||
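A short sketch of the updated tokenizer call on a loaded `SafeLlamaModelHandle` (assumed to be `model` here). The new `special` flag decides whether control tokens written in the text, such as "<s>", are mapped to their token ids or tokenized as plain characters.

    // Control tokens in the text are treated as plain text...
    var plain = model.Tokenize("<s>Hello", add_bos: true, special: false, encoding: Encoding.UTF8);

    // ...or resolved to their special token ids when the flag is set.
    var withSpecial = model.Tokenize("<s>Hello", add_bos: false, special: true, encoding: Encoding.UTF8);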
@@ -1,108 +0,0 @@
-using LLama.Abstractions;
-using LLama.Native;
-using System;
-using System.Collections.Generic;
-using System.Runtime.InteropServices;
-using System.Text;
-using LLama.Extensions;
-
-namespace LLama
-{
-    using llama_token = Int32;
-
-    /// <summary>
-    /// Assorted llama utilities
-    /// </summary>
-    public static class Utils
-    {
-        [Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")]
-#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
-        public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
-#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
-        {
-            using var weights = LLamaWeights.LoadFromFile(@params);
-            using (@params.ToLlamaContextParams(out var lparams))
-                return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams);
-        }
-
-        [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
-#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
-        public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
-#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
-        {
-            return ctx.Tokenize(text, add_bos, encoding);
-        }
-
-        [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
-#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
-        public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
-#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
-        {
-            if (length != ctx.VocabCount)
-                throw new ArgumentException("length must be the VocabSize");
-            return ctx.GetLogits();
-        }
-
-        [Obsolete("Use SafeLLamaContextHandle Eval method instead")]
-#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
-        public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
-#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
-        {
-            var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
-            return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
-        }
-
-        [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]
-#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
-        public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
-#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
-        {
-            return ctx.TokenToString(token, encoding);
-        }
-
-        [Obsolete("No longer used internally by LlamaSharp")]
-#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
-        public static string PtrToString(IntPtr ptr, Encoding encoding)
-#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
-        {
-#if NET6_0_OR_GREATER
-            // ReSharper disable once PossibleUnintendedReferenceComparison
-            if(encoding == Encoding.UTF8)
-            {
-                return Marshal.PtrToStringUTF8(ptr)!;
-            }
-            // ReSharper disable once PossibleUnintendedReferenceComparison
-            else if(encoding == Encoding.Unicode)
-            {
-                return Marshal.PtrToStringUni(ptr)!;
-            }
-            else
-            {
-                return Marshal.PtrToStringAuto(ptr)!;
-            }
-#else
-            unsafe
-            {
-                byte* tp = (byte*)ptr.ToPointer();
-                List<byte> bytes = new();
-                while (true)
-                {
-                    byte c = *tp++;
-                    if (c == '\0')
-                    {
-                        break;
-                    }
-                    else
-                    {
-                        bytes.Add(c);
-                    }
-                }
-                return encoding.GetString(bytes.ToArray());
-            }
-#endif
-        }
-    }
-}
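Each helper removed from Utils names its replacement in its [Obsolete] message above. A rough migration sketch based on those messages; variable names and the model path are illustrative, context.NativeHandle is assumed to expose the SafeLLamaContextHandle, and the context handle's Tokenize overload may also have gained the new special flag:

// Hedged migration sketch from the deleted Utils helpers to their replacements.
using System.Text;
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf");          // illustrative path

// Utils.InitLLamaContextFromModelParams(@params) -> LLamaWeights.LoadFromFile + CreateContext
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var ctx = context.NativeHandle;                                   // SafeLLamaContextHandle

// Utils.Tokenize(ctx, text, add_bos, encoding)   -> ctx.Tokenize(text, add_bos, encoding)
var tokens = ctx.Tokenize("Hello, world", true, Encoding.UTF8);

// Utils.GetLogits(ctx, length)                   -> ctx.GetLogits()
var logits = ctx.GetLogits();

// Utils.TokenToString(token, ctx, encoding)      -> ctx.TokenToString(token, encoding)
var piece = ctx.TokenToString(tokens[0], Encoding.UTF8);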