LLamaToken Struct (tags/v0.10.0)
| @@ -1,6 +1,5 @@ | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using LLama.Abstractions; | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| @@ -94,7 +93,7 @@ public class BatchedDecoding | |||
| var n_cur = batch.NativeBatch.n_tokens; | |||
| var n_decode = 0; | |||
| var streams = new List<int>[n_parallel]; | |||
| var streams = new List<LLamaToken>[n_parallel]; | |||
| for (var i = 0; i < n_parallel; i++) | |||
| streams[i] = new(); | |||
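The change above only swaps the element type of the per-sequence streams. A minimal sketch of how such a stream behaves with the new struct, assuming the usual `LLama.Native` using and with the `n_parallel` value and token id chosen purely for illustration:

```csharp
using System.Collections.Generic;
using LLama.Native;

const int n_parallel = 4;                       // illustrative only
var streams = new List<LLamaToken>[n_parallel];
for (var i = 0; i < n_parallel; i++)
    streams[i] = new List<LLamaToken>();

streams[0].Add(450);                            // int literals still work via the implicit conversion
var raw = (int)streams[0][0];                   // explicit cast to get the raw id back
```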
| @@ -1,4 +1,5 @@ | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| namespace LLama.Unittest | |||
| { | |||
| @@ -37,7 +38,7 @@ namespace LLama.Unittest | |||
| { | |||
| var tokens = _context.Tokenize("The quick brown fox", true); | |||
| Assert.Equal(new[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens); | |||
| Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens); | |||
| } | |||
| [Fact] | |||
| @@ -45,7 +46,7 @@ namespace LLama.Unittest | |||
| { | |||
| var tokens = _context.Tokenize("The quick brown fox", false); | |||
| Assert.Equal(new[] { 450, 4996, 17354, 1701, 29916 }, tokens); | |||
| Assert.Equal(new LLamaToken[] { 450, 4996, 17354, 1701, 29916 }, tokens); | |||
| } | |||
| [Fact] | |||
| @@ -53,7 +54,7 @@ namespace LLama.Unittest | |||
| { | |||
| var tokens = _context.Tokenize("", false); | |||
| Assert.Equal(Array.Empty<int>(), tokens); | |||
| Assert.Equal(Array.Empty<LLamaToken>(), tokens); | |||
| } | |||
| } | |||
| } | |||
| @@ -17,7 +17,7 @@ namespace LLama.Web.Common | |||
| public int MaxTokens { get; set; } = -1; | |||
| /// <inheritdoc /> | |||
| public Dictionary<int, float>? LogitBias { get; set; } = null; | |||
| public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null; | |||
| /// <inheritdoc /> | |||
| public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>(); | |||
| @@ -24,7 +24,7 @@ namespace LLama.Abstractions | |||
| /// <summary> | |||
| /// logit bias for specific tokens | |||
| /// </summary> | |||
| public Dictionary<int, float>? LogitBias { get; set; } | |||
| public Dictionary<LLamaToken, float>? LogitBias { get; set; } | |||
| /// <summary> | |||
| /// Sequences where the model will stop generating further tokens. | |||
| @@ -6,8 +6,6 @@ using LLama.Sampling; | |||
| namespace LLama.Common | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// The parameters used for inference. | |||
| /// </summary> | |||
| @@ -28,7 +26,7 @@ namespace LLama.Common | |||
| /// <summary> | |||
| /// logit bias for specific tokens | |||
| /// </summary> | |||
| public Dictionary<llama_token, float>? LogitBias { get; set; } = null; | |||
| public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null; | |||
| /// <summary> | |||
| /// Sequences where the model will stop generating further tokens. | |||
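Since `LogitBias` is now keyed by `LLamaToken`, existing integer token ids keep working through the implicit conversion. A small sketch, assuming `InferenceParams` from `LLama.Common`; the ids and bias values are illustrative only, not taken from the diff:

```csharp
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;

var inferenceParams = new InferenceParams
{
    LogitBias = new Dictionary<LLamaToken, float>
    {
        { 2, float.NegativeInfinity },   // forbid an id; 2 is only an assumed example
        { 450, 1.5f },                   // int keys convert implicitly to LLamaToken
    },
};
```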
| @@ -38,7 +38,7 @@ namespace LLama.Extensions | |||
| /// <returns></returns> | |||
| [Obsolete("Use an Antiprompt processor instead")] | |||
| internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding) | |||
| where TTokens : IReadOnlyList<int> | |||
| where TTokens : IReadOnlyList<LLamaToken> | |||
| where TQueries : IReadOnlyList<string> | |||
| { | |||
| if (queries == null || queries.Count == 0 || tokens.Count == 0) | |||
| @@ -79,7 +79,7 @@ namespace LLama.Extensions | |||
| /// <returns></returns> | |||
| [Obsolete("Use an Antiprompt processor instead")] | |||
| internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding) | |||
| where TTokens : IReadOnlyList<int> | |||
| where TTokens : IReadOnlyList<LLamaToken> | |||
| { | |||
| if (queries == null || queries.Count == 0 || tokens.Count == 0) | |||
| return false; | |||
| @@ -15,8 +15,6 @@ using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// A llama_context, which holds all the context required to interact with a model | |||
| /// </summary> | |||
| @@ -93,7 +91,7 @@ namespace LLama | |||
| /// <param name="addBos">Whether to add a bos to the text.</param> | |||
| /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param> | |||
| /// <returns></returns> | |||
| public llama_token[] Tokenize(string text, bool addBos = true, bool special = false) | |||
| public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false) | |||
| { | |||
| return NativeHandle.Tokenize(text, addBos, special, Encoding); | |||
| } | |||
| @@ -104,7 +102,7 @@ namespace LLama | |||
| /// <param name="tokens"></param> | |||
| /// <returns></returns> | |||
| [Obsolete("Use a `StreamingTokenDecoder` instead")] | |||
| public string DeTokenize(IReadOnlyList<llama_token> tokens) | |||
| public string DeTokenize(IReadOnlyList<LLamaToken> tokens) | |||
| { | |||
| // Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder! | |||
| // It should be kept around for the entire time you are decoding one stream of tokens. | |||
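A rough round-trip sketch with the updated signatures, assuming `context` is an existing `LLamaContext`; the `StreamingTokenDecoder` is preferred over the obsolete `DeTokenize`:

```csharp
using LLama;
using LLama.Native;

// context: assumed pre-existing LLamaContext
LLamaToken[] tokens = context.Tokenize("The quick brown fox", addBos: true);

// Keep one decoder alive for the whole stream of tokens being decoded
var decoder = new StreamingTokenDecoder(context);
decoder.AddRange(tokens);
string text = decoder.Read();
```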
| @@ -219,7 +217,7 @@ namespace LLama | |||
| /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param> | |||
| /// <param name="lastTokens">The tokens recently returned from the model</param> | |||
| /// <returns>The selected token</returns> | |||
| public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens) | |||
| public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); | |||
| } | |||
| @@ -240,11 +238,11 @@ namespace LLama | |||
| /// <param name="grammar"></param> | |||
| /// <param name="minP"></param> | |||
| /// <returns></returns> | |||
| public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat, | |||
| float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP, | |||
| SafeLLamaGrammarHandle? grammar, float minP) | |||
| public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat, | |||
| float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP, | |||
| SafeLLamaGrammarHandle? grammar, float minP) | |||
| { | |||
| llama_token id; | |||
| LLamaToken id; | |||
| if (grammar != null) | |||
| { | |||
| @@ -301,7 +299,7 @@ namespace LLama | |||
| /// <param name="alphaPresence"></param> | |||
| /// <param name="penalizeNL"></param> | |||
| /// <returns></returns> | |||
| public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null, | |||
| public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null, | |||
| int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, | |||
| bool penalizeNL = true) | |||
| { | |||
| @@ -311,12 +309,12 @@ namespace LLama | |||
| if (logitBias is not null) | |||
| { | |||
| foreach (var (key, value) in logitBias) | |||
| logits[key] += value; | |||
| logits[(int)key] += value; | |||
| } | |||
| // Save the newline logit value | |||
| var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); | |||
| var nl_logit = logits[nl_token]; | |||
| var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); | |||
| var nl_logit = logits[(int)nl_token]; | |||
| // Convert logits into token candidates | |||
| var candidates_p = LLamaTokenDataArray.Create(logits); | |||
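The `(int)` casts above exist because a `LLamaToken` no longer indexes a span or array directly. A tiny sketch of the pattern; the vocabulary size is an assumed placeholder:

```csharp
using LLama.Native;

var logits = new float[32000];      // assumed vocab size, for illustration only
LLamaToken token = 29916;
logits[(int)token] += 1.5f;         // explicit cast back to int when indexing
```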
| @@ -353,7 +351,7 @@ namespace LLama | |||
| /// <returns>The updated `pastTokensCount`.</returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| [Obsolete("use llama_decode() instead")] | |||
| public int Eval(llama_token[] tokens, int pastTokensCount) | |||
| public int Eval(LLamaToken[] tokens, int pastTokensCount) | |||
| { | |||
| return Eval(tokens.AsSpan(), pastTokensCount); | |||
| } | |||
| @@ -366,7 +364,7 @@ namespace LLama | |||
| /// <returns>The updated `pastTokensCount`.</returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| [Obsolete("use llama_decode() instead")] | |||
| public int Eval(List<llama_token> tokens, int pastTokensCount) | |||
| public int Eval(List<LLamaToken> tokens, int pastTokensCount) | |||
| { | |||
| #if NET5_0_OR_GREATER | |||
| var span = CollectionsMarshal.AsSpan(tokens); | |||
| @@ -376,7 +374,7 @@ namespace LLama | |||
| // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't | |||
| // avoid the copying. | |||
| var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count); | |||
| var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count); | |||
| try | |||
| { | |||
| tokens.CopyTo(rented, 0); | |||
| @@ -384,7 +382,7 @@ namespace LLama | |||
| } | |||
| finally | |||
| { | |||
| System.Buffers.ArrayPool<llama_token>.Shared.Return(rented); | |||
| System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented); | |||
| } | |||
| #endif | |||
| } | |||
| @@ -397,7 +395,7 @@ namespace LLama | |||
| /// <returns>The updated `pastTokensCount`.</returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| [Obsolete("use llama_decode() instead")] | |||
| public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount) | |||
| public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount) | |||
| { | |||
| return Eval(tokens.Span, pastTokensCount); | |||
| } | |||
| @@ -410,7 +408,7 @@ namespace LLama | |||
| /// <returns>The updated `pastTokensCount`.</returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| [Obsolete("use llama_decode() instead")] | |||
| public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount) | |||
| public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount) | |||
| { | |||
| var total = tokens.Length; | |||
| for(var i = 0; i < total; i += (int)Params.BatchSize) | |||
| @@ -14,7 +14,6 @@ using System.Threading.Tasks; | |||
| namespace LLama | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// The base class for stateful LLama executors. | |||
| /// </summary> | |||
| @@ -47,19 +46,19 @@ namespace LLama | |||
| /// <summary> | |||
| /// A container of the tokens to be processed and those already processed. | |||
| /// </summary> | |||
| protected List<llama_token> _embeds = new(); // embd | |||
| protected List<LLamaToken> _embeds = new(); // embd | |||
| /// <summary> | |||
| /// A container for the tokens of input. | |||
| /// </summary> | |||
| protected List<llama_token> _embed_inps = new(); | |||
| protected List<LLamaToken> _embed_inps = new(); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| protected List<llama_token> _session_tokens = new(); | |||
| protected List<LLamaToken> _session_tokens = new(); | |||
| /// <summary> | |||
| /// The last tokens generated by the model. | |||
| /// </summary> | |||
| protected FixedSizeQueue<llama_token> _last_n_tokens; | |||
| protected FixedSizeQueue<LLamaToken> _last_n_tokens; | |||
| /// <summary> | |||
| /// The context used by the executor. | |||
| /// </summary> | |||
| @@ -84,7 +83,7 @@ namespace LLama | |||
| _pastTokensCount = 0; | |||
| _consumedTokensCount = 0; | |||
| _n_session_consumed = 0; | |||
| _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize); | |||
| _last_n_tokens = new FixedSizeQueue<LLamaToken>(Context.ContextSize); | |||
| _decoder = new StreamingTokenDecoder(context); | |||
| } | |||
| @@ -105,7 +104,7 @@ namespace LLama | |||
| if (File.Exists(filename)) | |||
| { | |||
| _logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}"); | |||
| var session_tokens = new llama_token[Context.ContextSize]; | |||
| var session_tokens = new LLamaToken[Context.ContextSize]; | |||
| if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out)) | |||
| { | |||
| _logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}"); | |||
| @@ -361,16 +360,16 @@ namespace LLama | |||
| public string? SessionFilePath { get; set; } | |||
| [JsonPropertyName("embd")] | |||
| public List<llama_token> Embeds { get; set; } | |||
| public List<LLamaToken> Embeds { get; set; } | |||
| [JsonPropertyName("embd_inps")] | |||
| public List<llama_token> EmbedInps { get; set; } | |||
| public List<LLamaToken> EmbedInps { get; set; } | |||
| [JsonPropertyName("session_tokens")] | |||
| public List<llama_token> SessionTokens { get; set; } | |||
| public List<LLamaToken> SessionTokens { get; set; } | |||
| [JsonPropertyName("last_n_tokens")] | |||
| public llama_token[] LastTokens { get; set; } | |||
| public LLamaToken[] LastTokens { get; set; } | |||
| [JsonPropertyName("last_tokens_maximum_count")] | |||
| public int LastTokensCapacity { get; set; } | |||
| @@ -13,7 +13,6 @@ using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// The LLama executor for instruct mode. | |||
| /// </summary> | |||
| @@ -22,8 +21,8 @@ namespace LLama | |||
| { | |||
| private bool _is_prompt_run = true; | |||
| private readonly string _instructionPrefix; | |||
| private llama_token[] _inp_pfx; | |||
| private llama_token[] _inp_sfx; | |||
| private LLamaToken[] _inp_pfx; | |||
| private LLamaToken[] _inp_sfx; | |||
| /// <summary> | |||
| /// | |||
| @@ -75,7 +74,7 @@ namespace LLama | |||
| _is_prompt_run = state.IsPromptRun; | |||
| _consumedTokensCount = state.ConsumedTokensCount; | |||
| _embeds = state.Embeds; | |||
| _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens); | |||
| _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens); | |||
| _inp_pfx = state.InputPrefixTokens; | |||
| _inp_sfx = state.InputSuffixTokens; | |||
| _n_matching_session_tokens = state.MatchingSessionTokensCount; | |||
| @@ -210,7 +209,7 @@ namespace LLama | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| llama_token id; | |||
| LLamaToken id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); | |||
| @@ -266,12 +265,12 @@ namespace LLama | |||
| /// Instruction prefix tokens. | |||
| /// </summary> | |||
| [JsonPropertyName("inp_pfx")] | |||
| public llama_token[] InputPrefixTokens { get; set; } | |||
| public LLamaToken[] InputPrefixTokens { get; set; } | |||
| /// <summary> | |||
| /// Instruction suffix tokens. | |||
| /// </summary> | |||
| [JsonPropertyName("inp_sfx")] | |||
| public llama_token[] InputSuffixTokens { get; set; } | |||
| public LLamaToken[] InputSuffixTokens { get; set; } | |||
| } | |||
| } | |||
| } | |||
| @@ -13,14 +13,13 @@ using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// The LLama executor for interactive mode. | |||
| /// </summary> | |||
| public class InteractiveExecutor : StatefulExecutorBase | |||
| { | |||
| private bool _is_prompt_run = true; | |||
| private readonly llama_token _llama_token_newline; | |||
| private readonly LLamaToken _llama_token_newline; | |||
| /// <summary> | |||
| /// | |||
| @@ -63,7 +62,7 @@ namespace LLama | |||
| _is_prompt_run = state.IsPromptRun; | |||
| _consumedTokensCount = state.ConsumedTokensCount; | |||
| _embeds = state.Embeds; | |||
| _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens); | |||
| _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens); | |||
| _n_matching_session_tokens = state.MatchingSessionTokensCount; | |||
| _pastTokensCount = state.PastTokensCount; | |||
| _pathSession = state.SessionFilePath; | |||
| @@ -189,7 +188,7 @@ namespace LLama | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| llama_token id; | |||
| LLamaToken id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); | |||
| @@ -12,8 +12,6 @@ using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// This executor infers the input as a one-time job. Previous inputs won't impact the | |||
| /// response to the current input. | |||
| @@ -71,9 +69,9 @@ namespace LLama | |||
| // Keep track of the last N tokens emitted | |||
| var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount); | |||
| var lastTokens = new List<llama_token>(repeat_last_n); | |||
| var lastTokens = new List<LLamaToken>(repeat_last_n); | |||
| for (var i = 0; i < repeat_last_n; i++) | |||
| lastTokens.Add(0); | |||
| lastTokens.Add((LLamaToken)0); | |||
| // Tokenize the prompt | |||
| var tokens = Context.Tokenize(prompt).ToList(); | |||
| @@ -89,7 +87,7 @@ namespace LLama | |||
| var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | |||
| for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) | |||
| { | |||
| llama_token id; | |||
| LLamaToken id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens); | |||
| @@ -42,17 +42,17 @@ namespace LLama | |||
| /// <summary> | |||
| /// Get the newline token for this model | |||
| /// </summary> | |||
| public int NewlineToken => NativeApi.llama_token_nl(NativeHandle); | |||
| public LLamaToken NewlineToken => NativeApi.llama_token_nl(NativeHandle); | |||
| /// <summary> | |||
| /// Get the "end of sentence" token for this model | |||
| /// </summary> | |||
| public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle); | |||
| public LLamaToken EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle); | |||
| /// <summary> | |||
| /// Get the "beginning of sentence" token for this model | |||
| /// </summary> | |||
| public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle); | |||
| public LLamaToken BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle); | |||
| /// <summary> | |||
| /// Dimension of embedding vectors | |||
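A brief sketch of the updated special-token properties, assuming (as the surrounding hunk suggests) they are exposed on a loaded model object such as `LLamaWeights`; the sampled id is illustrative:

```csharp
using LLama;
using LLama.Native;

// weights: assumed loaded LLamaWeights instance
LLamaToken eos = weights.EndOfSentenceToken;
LLamaToken sampled = 2;                        // illustrative id only
bool finished = sampled == eos;                // value equality from the record struct
```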
| @@ -2,8 +2,6 @@ | |||
| namespace LLama.Native; | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// Input data for llama_decode. A llama_batch object can contain input about one or many sequences. | |||
| /// </summary> | |||
| @@ -20,16 +18,16 @@ public sealed class LLamaBatchSafeHandle | |||
| /// <summary> | |||
| /// the token ids of the input (used when embd is NULL) | |||
| /// </summary> | |||
| public Span<llama_token> Token | |||
| public Span<LLamaToken> Token | |||
| { | |||
| get | |||
| { | |||
| unsafe | |||
| { | |||
| if (_embd != 0) | |||
| return new Span<int>(null, 0); | |||
| return new Span<LLamaToken>(null, 0); | |||
| else | |||
| return new Span<int>(NativeBatch.token, NativeBatch.n_tokens); | |||
| return new Span<LLamaToken>(NativeBatch.token, NativeBatch.n_tokens); | |||
| } | |||
| } | |||
| } | |||
| @@ -37,7 +35,7 @@ public sealed class LLamaBatchSafeHandle | |||
| /// <summary> | |||
| /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL) | |||
| /// </summary> | |||
| public Span<llama_token> Embed | |||
| public Span<LLamaToken> Embed | |||
| { | |||
| get | |||
| { | |||
| @@ -47,9 +45,9 @@ public sealed class LLamaBatchSafeHandle | |||
| // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | |||
| if (_embd != 0) | |||
| return new Span<llama_token>(NativeBatch.embd, NativeBatch.n_tokens * _embd); | |||
| return new Span<LLamaToken>(NativeBatch.embd, NativeBatch.n_tokens * _embd); | |||
| else | |||
| return new Span<llama_token>(null, 0); | |||
| return new Span<LLamaToken>(null, 0); | |||
| } | |||
| } | |||
| } | |||
| @@ -133,7 +131,7 @@ public sealed class LLamaBatchSafeHandle | |||
| /// <summary> | |||
| /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 | |||
| /// </summary> | |||
| public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits) | |||
| public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits) | |||
| { | |||
| unsafe | |||
| { | |||
| @@ -3,15 +3,13 @@ using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// Information about a single beam in a beam search | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct LLamaBeamView | |||
| { | |||
| private unsafe llama_token* tokens; | |||
| private unsafe LLamaToken* tokens; | |||
| private nint n_tokens; | |||
| /// <summary> | |||
| @@ -27,7 +25,7 @@ public struct LLamaBeamView | |||
| /// <summary> | |||
| /// Tokens in this beam | |||
| /// </summary> | |||
| public readonly Span<llama_token> Tokens | |||
| public readonly Span<LLamaToken> Tokens | |||
| { | |||
| get | |||
| { | |||
| @@ -35,7 +33,7 @@ public struct LLamaBeamView | |||
| { | |||
| if (n_tokens > int.MaxValue) | |||
| throw new InvalidOperationException("More than 2147483647 tokens is not supported"); | |||
| return new Span<llama_token>(tokens, (int)n_tokens); | |||
| return new Span<LLamaToken>(tokens, (int)n_tokens); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,10 +1,7 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// Input data for llama_decode | |||
| /// A llama_batch object can contain input about one or many sequences | |||
| @@ -21,7 +18,7 @@ public unsafe struct LLamaNativeBatch | |||
| /// <summary> | |||
| /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created | |||
| /// </summary> | |||
| public llama_token* token; | |||
| public LLamaToken* token; | |||
| /// <summary> | |||
| /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created | |||
| @@ -0,0 +1,38 @@ | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
| /// <summary> | |||
| /// A single token | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public readonly record struct LLamaToken | |||
| { | |||
| /// <summary> | |||
| /// The raw value | |||
| /// </summary> | |||
| private readonly int Value; | |||
| /// <summary> | |||
| /// Create a new LLamaToken | |||
| /// </summary> | |||
| /// <param name="value"></param> | |||
| private LLamaToken(int value) | |||
| { | |||
| Value = value; | |||
| } | |||
| /// <summary> | |||
| /// Convert a LLamaToken into an integer (extract the raw value) | |||
| /// </summary> | |||
| /// <param name="pos"></param> | |||
| /// <returns></returns> | |||
| public static explicit operator int(LLamaToken pos) => pos.Value; | |||
| /// <summary> | |||
| /// Convert an integer into a LLamaToken | |||
| /// </summary> | |||
| /// <param name="value"></param> | |||
| /// <returns></returns> | |||
| public static implicit operator LLamaToken(int value) => new(value); | |||
| } | |||
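The struct is deliberately a thin wrapper around the raw id. A quick sketch of how the two conversion operators defined above behave in practice:

```csharp
using LLama.Native;

LLamaToken token = 450;                 // implicit: int -> LLamaToken
int id = (int)token;                    // explicit: LLamaToken -> int
bool equal = token == (LLamaToken)450;  // record struct provides value equality
```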
| @@ -11,7 +11,7 @@ public struct LLamaTokenData | |||
| /// <summary> | |||
| /// token id | |||
| /// </summary> | |||
| public int id; | |||
| public LLamaToken id; | |||
| /// <summary> | |||
| /// log-odds of the token | |||
| @@ -29,7 +29,7 @@ public struct LLamaTokenData | |||
| /// <param name="id"></param> | |||
| /// <param name="logit"></param> | |||
| /// <param name="p"></param> | |||
| public LLamaTokenData(int id, float logit, float p) | |||
| public LLamaTokenData(LLamaToken id, float logit, float p) | |||
| { | |||
| this.id = id; | |||
| this.logit = logit; | |||
| @@ -2,8 +2,6 @@ | |||
| using System.Buffers; | |||
| using System.Runtime.InteropServices; | |||
| using llama_token = System.Int32; | |||
| namespace LLama.Native | |||
| { | |||
| /// <summary> | |||
| @@ -41,7 +39,7 @@ namespace LLama.Native | |||
| { | |||
| var candidates = new LLamaTokenData[logits.Length]; | |||
| for (var token_id = 0; token_id < logits.Length; token_id++) | |||
| candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); | |||
| candidates[token_id] = new LLamaTokenData((LLamaToken)token_id, logits[token_id], 0.0f); | |||
| return new LLamaTokenDataArray(candidates); | |||
| } | |||
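A hedged sketch of turning logits into candidates and sampling a token, assuming `ctx` is a live `SafeLLamaContextHandle` after a decode/eval call (the `GetLogits()` call is the same one used elsewhere in this diff):

```csharp
using LLama.Native;

// ctx: assumed SafeLLamaContextHandle with logits available
var logits = ctx.GetLogits();
var candidates = LLamaTokenDataArray.Create(logits);
LLamaToken next = candidates.SampleTokenGreedy(ctx);
```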
| @@ -50,7 +48,7 @@ namespace LLama.Native | |||
| /// Overwrite the logit values for all given tokens | |||
| /// </summary> | |||
| /// <param name="values">tuples of token and logit value to overwrite</param> | |||
| public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values) | |||
| public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values) | |||
| { | |||
| if (values.Length == 0) | |||
| return; | |||
| @@ -172,13 +170,13 @@ namespace LLama.Native | |||
| /// <param name="penalty_repeat"></param> | |||
| /// <param name="penalty_freq"></param> | |||
| /// <param name="penalty_present"></param> | |||
| public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) | |||
| public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLamaToken> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) | |||
| { | |||
| unsafe | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| fixed (int* last_tokens_handle = last_tokens) | |||
| fixed (LLamaToken* last_tokens_handle = last_tokens) | |||
| { | |||
| NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); | |||
| sorted = st.sorted; | |||
| @@ -220,7 +218,7 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <returns></returns> | |||
| public int SampleToken(SafeLLamaContextHandle context) | |||
| public LLamaToken SampleToken(SafeLLamaContextHandle context) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| @@ -235,7 +233,7 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <returns></returns> | |||
| public int SampleTokenGreedy(SafeLLamaContextHandle context) | |||
| public LLamaToken SampleTokenGreedy(SafeLLamaContextHandle context) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| @@ -254,7 +252,7 @@ namespace LLama.Native | |||
| /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param> | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu) | |||
| public LLamaToken SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| @@ -272,7 +270,7 @@ namespace LLama.Native | |||
| /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param> | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu) | |||
| public LLamaToken SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| @@ -3,8 +3,6 @@ using System.Runtime.InteropServices; | |||
| namespace LLama.Native | |||
| { | |||
| using llama_token = Int32; | |||
| public static partial class NativeApi | |||
| { | |||
| /// <summary> | |||
| @@ -48,6 +46,6 @@ namespace LLama.Native | |||
| /// <param name="grammar"></param> | |||
| /// <param name="token"></param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, llama_token token); | |||
| public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, LLamaToken token); | |||
| } | |||
| } | |||
| @@ -1,10 +1,7 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native | |||
| { | |||
| using llama_token = Int32; | |||
| public static partial class NativeApi | |||
| { | |||
| /// <summary> | |||
| @@ -21,7 +18,7 @@ namespace LLama.Native | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx, | |||
| ref LLamaTokenDataArrayNative candidates, | |||
| llama_token* last_tokens, ulong last_tokens_size, | |||
| LLamaToken* last_tokens, ulong last_tokens_size, | |||
| float penalty_repeat, | |||
| float penalty_freq, | |||
| float penalty_present); | |||
| @@ -115,7 +112,7 @@ namespace LLama.Native | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu); | |||
| public static extern LLamaToken llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu); | |||
| /// <summary> | |||
| /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | |||
| @@ -127,7 +124,7 @@ namespace LLama.Native | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu); | |||
| public static extern LLamaToken llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu); | |||
| /// <summary> | |||
| /// Selects the token with the highest probability. | |||
| @@ -136,7 +133,7 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); | |||
| public static extern LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); | |||
| /// <summary> | |||
| /// Randomly selects a token from the candidates based on their probabilities. | |||
| @@ -145,6 +142,6 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); | |||
| public static extern LLamaToken llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); | |||
| } | |||
| } | |||
| @@ -7,8 +7,6 @@ using System.Text; | |||
| namespace LLama.Native | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// Callback from llama.cpp with log messages | |||
| /// </summary> | |||
| @@ -141,7 +139,7 @@ namespace LLama.Native | |||
| /// <param name="n_token_count_out"></param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); | |||
| public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); | |||
| /// <summary> | |||
| /// Save session file | |||
| @@ -152,7 +150,7 @@ namespace LLama.Native | |||
| /// <param name="n_token_count"></param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens, ulong n_token_count); | |||
| public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count); | |||
| /// <summary> | |||
| /// Run the llama inference to obtain the logits and probabilities for the next token. | |||
| @@ -166,7 +164,7 @@ namespace LLama.Native | |||
| /// <returns>Returns 0 on success</returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| [Obsolete("use llama_decode() instead")] | |||
| public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past); | |||
| public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past); | |||
| /// <summary> | |||
| /// Convert the provided text into tokens. | |||
| @@ -181,7 +179,7 @@ namespace LLama.Native | |||
| /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | |||
| /// Returns a negative number on failure - the number of tokens that would have been returned | |||
| /// </returns> | |||
| public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special) | |||
| public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, LLamaToken[] tokens, int n_max_tokens, bool add_bos, bool special) | |||
| { | |||
| unsafe | |||
| { | |||
| @@ -202,7 +200,7 @@ namespace LLama.Native | |||
| // Do the actual tokenization | |||
| fixed (byte* arrayPtr = array) | |||
| fixed (llama_token* tokensPtr = tokens) | |||
| fixed (LLamaToken* tokensPtr = tokens) | |||
| return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special); | |||
| } | |||
| finally | |||
| @@ -213,13 +211,13 @@ namespace LLama.Native | |||
| } | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token); | |||
| public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token); | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token); | |||
| public static extern float llama_token_get_score(SafeLlamaModelHandle model, LLamaToken token); | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, llama_token token); | |||
| public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, LLamaToken token); | |||
| /// <summary> | |||
| /// Get the size of the context window for the model for this context | |||
| @@ -272,21 +270,21 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_token_bos(SafeLlamaModelHandle model); | |||
| public static extern LLamaToken llama_token_bos(SafeLlamaModelHandle model); | |||
| /// <summary> | |||
| /// Get the "End of sentence" token | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_token_eos(SafeLlamaModelHandle model); | |||
| public static extern LLamaToken llama_token_eos(SafeLlamaModelHandle model); | |||
| /// <summary> | |||
| /// Get the "new line" token | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_token_nl(SafeLlamaModelHandle model); | |||
| public static extern LLamaToken llama_token_nl(SafeLlamaModelHandle model); | |||
| /// <summary> | |||
| /// Returns -1 if unknown, 1 for true or 0 for false. | |||
| @@ -477,7 +475,7 @@ namespace LLama.Native | |||
| /// <param name="llamaToken"></param> | |||
| /// <param name="buffer">buffer to write string into</param> | |||
| /// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns> | |||
| public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token llamaToken, Span<byte> buffer) | |||
| public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span<byte> buffer) | |||
| { | |||
| unsafe | |||
| { | |||
| @@ -488,7 +486,7 @@ namespace LLama.Native | |||
| } | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")] | |||
| static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, llama_token llamaToken, byte* buffer, int length); | |||
| static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length); | |||
| } | |||
| /// <summary> | |||
| @@ -505,7 +503,7 @@ namespace LLama.Native | |||
| /// Returns a negative number on failure - the number of tokens that would have been returned | |||
| /// </returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special); | |||
| public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, bool add_bos, bool special); | |||
| /// <summary> | |||
| /// Register a callback to receive llama log messages | |||
| @@ -153,19 +153,19 @@ namespace LLama.Native | |||
| /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) | |||
| public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) | |||
| { | |||
| ThrowIfDisposed(); | |||
| if (string.IsNullOrEmpty(text) && !add_bos) | |||
| return Array.Empty<int>(); | |||
| return Array.Empty<LLamaToken>(); | |||
| // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't | |||
| // possibly be more than this. | |||
| var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0); | |||
| // "Rent" an array to write results into (avoiding an allocation of a large array) | |||
| var temporaryArray = ArrayPool<int>.Shared.Rent(count); | |||
| var temporaryArray = ArrayPool<LLamaToken>.Shared.Rent(count); | |||
| try | |||
| { | |||
| // Do the actual conversion | |||
| @@ -177,14 +177,14 @@ namespace LLama.Native | |||
| } | |||
| // Copy the results from the rented into an array which is exactly the right size | |||
| var result = new int[n]; | |||
| var result = new LLamaToken[n]; | |||
| Array.ConstrainedCopy(temporaryArray, 0, result, 0, n); | |||
| return result; | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<int>.Shared.Return(temporaryArray); | |||
| ArrayPool<LLamaToken>.Shared.Return(temporaryArray); | |||
| } | |||
| } | |||
| @@ -194,7 +194,7 @@ namespace LLama.Native | |||
| /// <param name="token">Token to decode</param> | |||
| /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | |||
| /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | |||
| public int TokenToSpan(int token, Span<byte> dest) | |||
| public int TokenToSpan(LLamaToken token, Span<byte> dest) | |||
| { | |||
| return ThrowIfDisposed().TokenToSpan(token, dest); | |||
| } | |||
| @@ -207,11 +207,11 @@ namespace LLama.Native | |||
| /// <param name="n_past">the number of tokens to use from previous eval calls</param> | |||
| /// <returns>Returns true on success</returns> | |||
| [Obsolete("use llama_decode() instead")] | |||
| public bool Eval(ReadOnlySpan<int> tokens, int n_past) | |||
| public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past) | |||
| { | |||
| unsafe | |||
| { | |||
| fixed (int* pinned = tokens) | |||
| fixed (LLamaToken* pinned = tokens) | |||
| { | |||
| // the entire `eval` system needs replacing with the new batch system! | |||
| var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past); | |||
| @@ -119,7 +119,7 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="token"></param> | |||
| public void AcceptToken(SafeLLamaContextHandle ctx, int token) | |||
| public void AcceptToken(SafeLLamaContextHandle ctx, LLamaToken token) | |||
| { | |||
| NativeApi.llama_grammar_accept_token(ctx, this, token); | |||
| } | |||
| @@ -123,12 +123,12 @@ namespace LLama.Native | |||
| /// <summary> | |||
| /// Convert a single llama token into bytes | |||
| /// </summary> | |||
| /// <param name="llama_token">Token to decode</param> | |||
| /// <param name="token">Token to decode</param> | |||
| /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | |||
| /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | |||
| public int TokenToSpan(int llama_token, Span<byte> dest) | |||
| public int TokenToSpan(LLamaToken token, Span<byte> dest) | |||
| { | |||
| var length = NativeApi.llama_token_to_piece(this, llama_token, dest); | |||
| var length = NativeApi.llama_token_to_piece(this, token, dest); | |||
| return Math.Abs(length); | |||
| } | |||
| @@ -143,12 +143,10 @@ namespace LLama.Native | |||
| /// filled with as many characters as possible, starting from the _last_ token. | |||
| /// </returns> | |||
| [Obsolete("Use a StreamingTokenDecoder instead")] | |||
| internal Span<char> TokensToSpan(IReadOnlyList<int> tokens, Span<char> dest, Encoding encoding) | |||
| internal Span<char> TokensToSpan(IReadOnlyList<LLamaToken> tokens, Span<char> dest, Encoding encoding) | |||
| { | |||
| var decoder = new StreamingTokenDecoder(encoding, this); | |||
| foreach (var token in tokens) | |||
| decoder.Add(token); | |||
| decoder.AddRange(tokens); | |||
| var str = decoder.Read(); | |||
| @@ -172,7 +170,7 @@ namespace LLama.Native | |||
| /// <param name="encoding"></param> | |||
| /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param> | |||
| /// <returns></returns> | |||
| public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) | |||
| public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) | |||
| { | |||
| // Convert string to bytes, adding one extra byte to the end (null terminator) | |||
| var bytesCount = encoding.GetByteCount(text); | |||
| @@ -191,11 +189,11 @@ namespace LLama.Native | |||
| fixed (byte* bytesPtr = &bytes[0]) | |||
| { | |||
| // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) | |||
| var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special); | |||
| var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special); | |||
| // Tokenize again, this time outputting into an array of exactly the right size | |||
| var tokens = new int[count]; | |||
| fixed (int* tokensPtr = &tokens[0]) | |||
| var tokens = new LLamaToken[count]; | |||
| fixed (LLamaToken* tokensPtr = &tokens[0]) | |||
| { | |||
| NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); | |||
| return tokens; | |||
| @@ -4,8 +4,6 @@ | |||
| namespace LLama.Native | |||
| { | |||
| using llama_token = Int32; | |||
| /// <summary> | |||
| /// Direct translation of the llama.cpp sampling API | |||
| /// </summary> | |||
| @@ -110,7 +108,7 @@ namespace LLama.Native | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")] | |||
| public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | |||
| public static LLamaToken llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | |||
| { | |||
| return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu); | |||
| } | |||
| @@ -125,7 +123,7 @@ namespace LLama.Native | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")] | |||
| public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | |||
| public static LLamaToken llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | |||
| { | |||
| return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu); | |||
| } | |||
| @@ -137,7 +135,7 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <returns></returns> | |||
| [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")] | |||
| public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| public static LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| { | |||
| return candidates.SampleTokenGreedy(ctx); | |||
| } | |||
| @@ -149,7 +147,7 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <returns></returns> | |||
| [Obsolete("use LLamaTokenDataArray SampleToken() method")] | |||
| public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| public static LLamaToken llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| { | |||
| return candidates.SampleToken(ctx); | |||
| } | |||
| @@ -12,21 +12,21 @@ public abstract class BaseSamplingPipeline | |||
| : ISamplingPipeline | |||
| { | |||
| private int _savedLogitsCount; | |||
| private (int index, float logit)[]? _savedLogits; | |||
| private (LLamaToken index, float logit)[]? _savedLogits; | |||
| /// <inheritdoc/> | |||
| public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens) | |||
| public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| var protectedLogits = GetProtectedTokens(ctx); | |||
| _savedLogitsCount = protectedLogits.Count; | |||
| _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount); | |||
| _savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount); | |||
| try | |||
| { | |||
| // Save the values of protected logits | |||
| for (var i = 0; i < protectedLogits.Count; i++) | |||
| { | |||
| var index = protectedLogits[i]; | |||
| var value = logits[index]; | |||
| var value = logits[(int)index]; | |||
| _savedLogits[i] = (index, value); | |||
| } | |||
| @@ -47,7 +47,7 @@ public abstract class BaseSamplingPipeline | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<(int, float)>.Shared.Return(_savedLogits); | |||
| ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits); | |||
| _savedLogits = null; | |||
| _savedLogitsCount = 0; | |||
| } | |||
| @@ -58,7 +58,7 @@ public abstract class BaseSamplingPipeline | |||
| /// Get all of the "protected" tokens that cannot be changed by ProcessLogits | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx); | |||
| protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx); | |||
| /// <summary> | |||
| /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits | |||
| @@ -74,7 +74,7 @@ public abstract class BaseSamplingPipeline | |||
| // Restore the values of protected logits | |||
| for (var i = 0; i < saved.Length; i++) | |||
| logits[saved[i].index] = saved[i].logit; | |||
| logits[(int)saved[i].index] = saved[i].logit; | |||
| } | |||
| /// <summary> | |||
| @@ -96,7 +96,7 @@ public abstract class BaseSamplingPipeline | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens); | |||
| protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||
| /// <summary> | |||
| /// Process the LLamaTokenDataArray and select a single token | |||
| @@ -105,7 +105,7 @@ public abstract class BaseSamplingPipeline | |||
| /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens); | |||
| protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens); | |||
| /// <summary> | |||
| /// Choose the final token from the candidates | |||
| @@ -113,7 +113,7 @@ public abstract class BaseSamplingPipeline | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates"></param> | |||
| /// <returns></returns> | |||
| protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates); | |||
| protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates); | |||
| /// <inheritdoc/> | |||
| public virtual void Reset() | |||
| @@ -99,27 +99,27 @@ public sealed class DefaultSamplingPipeline | |||
| /// </summary> | |||
| public bool PenalizeNewline { get; set; } = false; | |||
| private readonly int[] _newlineToken = new int[1]; | |||
| private readonly LLamaToken[] _newlineToken = new LLamaToken[1]; | |||
| /// <inheritdoc /> | |||
| protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx) | |||
| protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx) | |||
| { | |||
| if (PenalizeNewline) | |||
| return Array.Empty<int>(); | |||
| return Array.Empty<LLamaToken>(); | |||
| _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle); | |||
| return _newlineToken; | |||
| } | |||
| /// <inheritdoc /> | |||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens) | |||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| foreach (var (key, value) in LogitBias) | |||
| logits[key] += value; | |||
| } | |||
| /// <inheritdoc /> | |||
| protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens) | |||
| protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| // Apply penalties to candidates | |||
| candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); | |||
| @@ -142,7 +142,7 @@ public sealed class DefaultSamplingPipeline | |||
| } | |||
| /// <inheritdoc /> | |||
| protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| protected override LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| { | |||
| return candidates.SampleToken(ctx); | |||
| } | |||
| @@ -19,7 +19,7 @@ public interface ISamplingPipeline | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A span of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens); | |||
| LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||
| /// <summary> | |||
| /// Reset all internal state of the sampling pipeline | |||
| @@ -40,13 +40,13 @@ public static class ISamplingPipelineExtensions | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens) | |||
| public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens) | |||
| { | |||
| #if NET5_0_OR_GREATER | |||
| var span = CollectionsMarshal.AsSpan(lastTokens); | |||
| return pipeline.Sample(ctx, logits, span); | |||
| #else | |||
| var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count); | |||
| var copy = ArrayPool<LLamaToken>.Shared.Rent(lastTokens.Count); | |||
| try | |||
| { | |||
| lastTokens.CopyTo(copy); | |||
| @@ -54,7 +54,7 @@ public static class ISamplingPipelineExtensions | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<int>.Shared.Return(copy); | |||
| ArrayPool<LLamaToken>.Shared.Return(copy); | |||
| } | |||
| #endif | |||
| } | |||
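A rough end-to-end sketch of the pipeline API with the new token type; `context` is an assumed `LLamaContext`, and `DefaultSamplingPipeline` is assumed to be constructible with its default settings:

```csharp
using System.Collections.Generic;
using LLama;
using LLama.Native;
using LLama.Sampling;

var pipeline = new DefaultSamplingPipeline();        // assumed default constructor
var lastTokens = new List<LLamaToken>();

// The List<LLamaToken> overload comes from ISamplingPipelineExtensions above
LLamaToken id = pipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
lastTokens.Add(id);
```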
| @@ -69,7 +69,7 @@ namespace LLama | |||
| /// Add a single token to the decoder | |||
| /// </summary> | |||
| /// <param name="token"></param> | |||
| public void Add(int token) | |||
| public void Add(LLamaToken token) | |||
| { | |||
| var charsArr = ArrayPool<char>.Shared.Rent(16); | |||
| var bytesArr = ArrayPool<byte>.Shared.Rent(16); | |||
| @@ -108,7 +108,7 @@ namespace LLama | |||
| // Converts a single token into bytes, using the `bytes` array as temporary storage. | |||
| // If the `bytes` array is too small it will get a larger one from the ArrayPool. | |||
| static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model) | |||
| static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model) | |||
| { | |||
| // Try to get bytes | |||
| var l = model.TokenToSpan(token, bytes); | |||
| @@ -129,6 +129,15 @@ namespace LLama | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Add a single token to the decoder | |||
| /// </summary> | |||
| /// <param name="token"></param> | |||
| public void Add(int token) | |||
| { | |||
| Add((LLamaToken)token); | |||
| } | |||
| /// <summary> | |||
| /// Add all tokens in the given enumerable | |||
| /// </summary> | |||
| @@ -139,6 +148,16 @@ namespace LLama | |||
| Add(item); | |||
| } | |||
| /// <summary> | |||
| /// Add all tokens in the given enumerable | |||
| /// </summary> | |||
| /// <param name="tokens"></param> | |||
| public void AddRange(IEnumerable<LLamaToken> tokens) | |||
| { | |||
| foreach (var item in tokens) | |||
| Add((int)item); | |||
| } | |||
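A brief sketch of incremental decoding with the two `Add` overloads shown above; `context` is an assumed `LLamaContext` and the token ids are illustrative:

```csharp
using System;
using LLama;
using LLama.Native;

var decoder = new StreamingTokenDecoder(context);
decoder.Add((LLamaToken)450);      // typed overload
decoder.Add(4996);                 // int overload forwards to the same path
Console.Write(decoder.Read());     // read whatever text is decodable so far, then clear
```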
| /// <summary> | |||
| /// Read all decoded characters and clear the buffer | |||
| /// </summary> | |||