
Merge pull request #404 from martindevans/switched_to_LLamaToken_struct

LLamaToken Struct
tags/v0.10.0
Martin Evans (GitHub) · 1 year ago
parent commit 402a110a3a
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
29 changed files with 196 additions and 168 deletions
  1. LLama.Examples/Examples/BatchedDecoding.cs (+1, -2)
  2. LLama.Unittest/LLamaContextTests.cs (+4, -3)
  3. LLama.Web/Common/InferenceOptions.cs (+1, -1)
  4. LLama/Abstractions/IInferenceParams.cs (+1, -1)
  5. LLama/Common/InferenceParams.cs (+1, -3)
  6. LLama/Extensions/IReadOnlyListExtensions.cs (+2, -2)
  7. LLama/LLamaContext.cs (+17, -19)
  8. LLama/LLamaExecutorBase.cs (+10, -11)
  9. LLama/LLamaInstructExecutor.cs (+6, -7)
  10. LLama/LLamaInteractExecutor.cs (+3, -4)
  11. LLama/LLamaStatelessExecutor.cs (+3, -5)
  12. LLama/LLamaWeights.cs (+3, -3)
  13. LLama/Native/LLamaBatchSafeHandle.cs (+7, -9)
  14. LLama/Native/LLamaBeamView.cs (+3, -5)
  15. LLama/Native/LLamaNativeBatch.cs (+2, -5)
  16. LLama/Native/LLamaToken.cs (+38, -0)
  17. LLama/Native/LLamaTokenData.cs (+2, -2)
  18. LLama/Native/LLamaTokenDataArray.cs (+8, -10)
  19. LLama/Native/NativeApi.Grammar.cs (+1, -3)
  20. LLama/Native/NativeApi.Sampling.cs (+6, -9)
  21. LLama/Native/NativeApi.cs (+14, -16)
  22. LLama/Native/SafeLLamaContextHandle.cs (+8, -8)
  23. LLama/Native/SafeLLamaGrammarHandle.cs (+1, -1)
  24. LLama/Native/SafeLlamaModelHandle.cs (+9, -11)
  25. LLama/Native/SamplingApi.cs (+4, -6)
  26. LLama/Sampling/BaseSamplingPipeline.cs (+10, -10)
  27. LLama/Sampling/DefaultSamplingPipeline.cs (+6, -6)
  28. LLama/Sampling/ISamplingPipeline.cs (+4, -4)
  29. LLama/StreamingTokenDecoder.cs (+21, -2)

LLama.Examples/Examples/BatchedDecoding.cs (+1, -2)

@@ -1,6 +1,5 @@
using System.Diagnostics;
using System.Text;
using LLama.Abstractions;
using LLama.Common;
using LLama.Native;

@@ -94,7 +93,7 @@ public class BatchedDecoding
var n_cur = batch.NativeBatch.n_tokens;
var n_decode = 0;

var streams = new List<int>[n_parallel];
var streams = new List<LLamaToken>[n_parallel];
for (var i = 0; i < n_parallel; i++)
streams[i] = new();



LLama.Unittest/LLamaContextTests.cs (+4, -3)

@@ -1,4 +1,5 @@
using LLama.Common;
using LLama.Native;

namespace LLama.Unittest
{
@@ -37,7 +38,7 @@ namespace LLama.Unittest
{
var tokens = _context.Tokenize("The quick brown fox", true);

Assert.Equal(new[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
}

[Fact]
@@ -45,7 +46,7 @@ namespace LLama.Unittest
{
var tokens = _context.Tokenize("The quick brown fox", false);

Assert.Equal(new[] { 450, 4996, 17354, 1701, 29916 }, tokens);
Assert.Equal(new LLamaToken[] { 450, 4996, 17354, 1701, 29916 }, tokens);
}

[Fact]
@@ -53,7 +54,7 @@ namespace LLama.Unittest
{
var tokens = _context.Tokenize("", false);

Assert.Equal(Array.Empty<int>(), tokens);
Assert.Equal(Array.Empty<LLamaToken>(), tokens);
}
}
}

LLama.Web/Common/InferenceOptions.cs (+1, -1)

@@ -17,7 +17,7 @@ namespace LLama.Web.Common
public int MaxTokens { get; set; } = -1;

/// <inheritdoc />
public Dictionary<int, float>? LogitBias { get; set; } = null;
public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <inheritdoc />
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();


LLama/Abstractions/IInferenceParams.cs (+1, -1)

@@ -24,7 +24,7 @@ namespace LLama.Abstractions
/// <summary>
/// logit bias for specific tokens
/// </summary>
public Dictionary<int, float>? LogitBias { get; set; }
public Dictionary<LLamaToken, float>? LogitBias { get; set; }

/// <summary>
/// Sequences where the model will stop generating further tokens.


LLama/Common/InferenceParams.cs (+1, -3)

@@ -6,8 +6,6 @@ using LLama.Sampling;

namespace LLama.Common
{
using llama_token = Int32;

/// <summary>
/// The paramters used for inference.
/// </summary>
@@ -28,7 +26,7 @@ namespace LLama.Common
/// <summary>
/// logit bias for specific tokens
/// </summary>
public Dictionary<llama_token, float>? LogitBias { get; set; } = null;
public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <summary>
/// Sequences where the model will stop generating further tokens.
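
Because LogitBias is now keyed by LLamaToken rather than int, bias entries can be built directly from tokens exposed elsewhere in the library. A minimal sketch of that usage, not taken from this PR (the wrapper class and the -10.0f bias value are illustrative assumptions):

    using System.Collections.Generic;
    using LLama;
    using LLama.Common;
    using LLama.Native;

    static class LogitBiasSketch
    {
        // Build inference parameters that discourage the model's newline token.
        // LLamaToken is a readonly record struct, so it works directly as a dictionary key.
        static InferenceParams Build(LLamaWeights weights)
        {
            return new InferenceParams
            {
                LogitBias = new Dictionary<LLamaToken, float>
                {
                    [weights.NewlineToken] = -10.0f, // NewlineToken now returns an LLamaToken
                },
            };
        }
    }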


LLama/Extensions/IReadOnlyListExtensions.cs (+2, -2)

@@ -38,7 +38,7 @@ namespace LLama.Extensions
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TTokens : IReadOnlyList<LLamaToken>
where TQueries : IReadOnlyList<string>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
@@ -79,7 +79,7 @@ namespace LLama.Extensions
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TTokens : IReadOnlyList<LLamaToken>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;


LLama/LLamaContext.cs (+17, -19)

@@ -15,8 +15,6 @@ using Microsoft.Extensions.Logging;

namespace LLama
{
using llama_token = Int32;

/// <summary>
/// A llama_context, which holds all the context required to interact with a model
/// </summary>
@@ -93,7 +91,7 @@ namespace LLama
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false)
{
return NativeHandle.Tokenize(text, addBos, special, Encoding);
}
@@ -104,7 +102,7 @@ namespace LLama
/// <param name="tokens"></param>
/// <returns></returns>
[Obsolete("Use a `StreamingTokenDecoder` instead")]
public string DeTokenize(IReadOnlyList<llama_token> tokens)
public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
{
// Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
// It should be kept around for the entire time you are decoding one stream of tokens.
@@ -219,7 +217,7 @@ namespace LLama
/// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
/// <param name="lastTokens">The tokens recently returned from the model</param>
/// <returns>The selected token</returns>
public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
{
return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
}
@@ -240,11 +238,11 @@ namespace LLama
/// <param name="grammar"></param>
/// <param name="minP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
SafeLLamaGrammarHandle? grammar, float minP)
public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
SafeLLamaGrammarHandle? grammar, float minP)
{
llama_token id;
LLamaToken id;

if (grammar != null)
{
@@ -301,7 +299,7 @@ namespace LLama
/// <param name="alphaPresence"></param>
/// <param name="penalizeNL"></param>
/// <returns></returns>
public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
@@ -311,12 +309,12 @@ namespace LLama
if (logitBias is not null)
{
foreach (var (key, value) in logitBias)
logits[key] += value;
logits[(int)key] += value;
}

// Save the newline logit value
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[(int)nl_token];

// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);
@@ -353,7 +351,7 @@ namespace LLama
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(llama_token[] tokens, int pastTokensCount)
public int Eval(LLamaToken[] tokens, int pastTokensCount)
{
return Eval(tokens.AsSpan(), pastTokensCount);
}
@@ -366,7 +364,7 @@ namespace LLama
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(List<llama_token> tokens, int pastTokensCount)
public int Eval(List<LLamaToken> tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
@@ -376,7 +374,7 @@ namespace LLama
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count);
var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
@@ -384,7 +382,7 @@ namespace LLama
}
finally
{
System.Buffers.ArrayPool<llama_token>.Shared.Return(rented);
System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented);
}
#endif
}
@@ -397,7 +395,7 @@ namespace LLama
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
{
return Eval(tokens.Span, pastTokensCount);
}
@@ -410,7 +408,7 @@ namespace LLama
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
{
var total = tokens.Length;
for(var i = 0; i < total; i += (int)Params.BatchSize)


LLama/LLamaExecutorBase.cs (+10, -11)

@@ -14,7 +14,6 @@ using System.Threading.Tasks;

namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The base class for stateful LLama executors.
/// </summary>
@@ -47,19 +46,19 @@ namespace LLama
/// <summary>
/// A container of the tokens to be processed and after processed.
/// </summary>
protected List<llama_token> _embeds = new(); // embd
protected List<LLamaToken> _embeds = new(); // embd
/// <summary>
/// A container for the tokens of input.
/// </summary>
protected List<llama_token> _embed_inps = new();
protected List<LLamaToken> _embed_inps = new();
/// <summary>
///
/// </summary>
protected List<llama_token> _session_tokens = new();
protected List<LLamaToken> _session_tokens = new();
/// <summary>
/// The last tokens generated by the model.
/// </summary>
protected FixedSizeQueue<llama_token> _last_n_tokens;
protected FixedSizeQueue<LLamaToken> _last_n_tokens;
/// <summary>
/// The context used by the executor.
/// </summary>
@@ -84,7 +83,7 @@ namespace LLama
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
_last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize);
_last_n_tokens = new FixedSizeQueue<LLamaToken>(Context.ContextSize);
_decoder = new StreamingTokenDecoder(context);
}

@@ -105,7 +104,7 @@ namespace LLama
if (File.Exists(filename))
{
_logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
var session_tokens = new llama_token[Context.ContextSize];
var session_tokens = new LLamaToken[Context.ContextSize];
if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
{
_logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");
@@ -361,16 +360,16 @@ namespace LLama
public string? SessionFilePath { get; set; }

[JsonPropertyName("embd")]
public List<llama_token> Embeds { get; set; }
public List<LLamaToken> Embeds { get; set; }

[JsonPropertyName("embd_inps")]
public List<llama_token> EmbedInps { get; set; }
public List<LLamaToken> EmbedInps { get; set; }

[JsonPropertyName("session_tokens")]
public List<llama_token> SessionTokens { get; set; }
public List<LLamaToken> SessionTokens { get; set; }

[JsonPropertyName("last_n_tokens")]
public llama_token[] LastTokens { get; set; }
public LLamaToken[] LastTokens { get; set; }

[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }


LLama/LLamaInstructExecutor.cs (+6, -7)

@@ -13,7 +13,6 @@ using Microsoft.Extensions.Logging;

namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The LLama executor for instruct mode.
/// </summary>
@@ -22,8 +21,8 @@ namespace LLama
{
private bool _is_prompt_run = true;
private readonly string _instructionPrefix;
private llama_token[] _inp_pfx;
private llama_token[] _inp_sfx;
private LLamaToken[] _inp_pfx;
private LLamaToken[] _inp_sfx;

/// <summary>
///
@@ -75,7 +74,7 @@ namespace LLama
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
_last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
_last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_inp_pfx = state.InputPrefixTokens;
_inp_sfx = state.InputSuffixTokens;
_n_matching_session_tokens = state.MatchingSessionTokensCount;
@@ -210,7 +209,7 @@ namespace LLama
SaveSessionFile(_pathSession);
}

llama_token id;
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
@@ -266,12 +265,12 @@ namespace LLama
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
public llama_token[] InputPrefixTokens { get; set; }
public LLamaToken[] InputPrefixTokens { get; set; }
/// <summary>
/// Instruction suffix tokens.
/// </summary>
[JsonPropertyName("inp_sfx")]
public llama_token[] InputSuffixTokens { get; set; }
public LLamaToken[] InputSuffixTokens { get; set; }
}
}
}

LLama/LLamaInteractExecutor.cs (+3, -4)

@@ -13,14 +13,13 @@ using Microsoft.Extensions.Logging;

namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The LLama executor for interactive mode.
/// </summary>
public class InteractiveExecutor : StatefulExecutorBase
{
private bool _is_prompt_run = true;
private readonly llama_token _llama_token_newline;
private readonly LLamaToken _llama_token_newline;

/// <summary>
///
@@ -63,7 +62,7 @@ namespace LLama
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
_last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
_last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_n_matching_session_tokens = state.MatchingSessionTokensCount;
_pastTokensCount = state.PastTokensCount;
_pathSession = state.SessionFilePath;
@@ -189,7 +188,7 @@ namespace LLama
SaveSessionFile(_pathSession);
}

llama_token id;
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());


LLama/LLamaStatelessExecutor.cs (+3, -5)

@@ -12,8 +12,6 @@ using Microsoft.Extensions.Logging;

namespace LLama
{
using llama_token = Int32;

/// <summary>
/// This executor infer the input as one-time job. Previous inputs won't impact on the
/// response to current input.
@@ -71,9 +69,9 @@ namespace LLama

// Keep track of the last N tokens emitted
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<llama_token>(repeat_last_n);
var lastTokens = new List<LLamaToken>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);
lastTokens.Add((LLamaToken)0);

// Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList();
@@ -89,7 +87,7 @@ namespace LLama
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
llama_token id;
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);


LLama/LLamaWeights.cs (+3, -3)

@@ -42,17 +42,17 @@ namespace LLama
/// <summary>
/// Get the newline token for this model
/// </summary>
public int NewlineToken => NativeApi.llama_token_nl(NativeHandle);
public LLamaToken NewlineToken => NativeApi.llama_token_nl(NativeHandle);

/// <summary>
/// Get the "end of sentence" token for this model
/// </summary>
public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle);
public LLamaToken EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle);

/// <summary>
/// Get the "beginning of sentence" token for this model
/// </summary>
public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle);
public LLamaToken BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle);

/// <summary>
/// Dimension of embedding vectors


LLama/Native/LLamaBatchSafeHandle.cs (+7, -9)

@@ -2,8 +2,6 @@

namespace LLama.Native;

using llama_token = Int32;

/// <summary>
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
/// </summary>
@@ -20,16 +18,16 @@ public sealed class LLamaBatchSafeHandle
/// <summary>
/// the token ids of the input (used when embd is NULL)
/// </summary>
public Span<llama_token> Token
public Span<LLamaToken> Token
{
get
{
unsafe
{
if (_embd != 0)
return new Span<int>(null, 0);
return new Span<LLamaToken>(null, 0);
else
return new Span<int>(NativeBatch.token, NativeBatch.n_tokens);
return new Span<LLamaToken>(NativeBatch.token, NativeBatch.n_tokens);
}
}
}
@@ -37,7 +35,7 @@ public sealed class LLamaBatchSafeHandle
/// <summary>
/// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
/// </summary>
public Span<llama_token> Embed
public Span<LLamaToken> Embed
{
get
{
@@ -47,9 +45,9 @@ public sealed class LLamaBatchSafeHandle
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token

if (_embd != 0)
return new Span<llama_token>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
return new Span<LLamaToken>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
else
return new Span<llama_token>(null, 0);
return new Span<LLamaToken>(null, 0);
}
}
}
@@ -133,7 +131,7 @@ public sealed class LLamaBatchSafeHandle
/// <summary>
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
/// </summary>
public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
unsafe
{


LLama/Native/LLamaBeamView.cs (+3, -5)

@@ -3,15 +3,13 @@ using System.Runtime.InteropServices;

namespace LLama.Native;

using llama_token = Int32;

/// <summary>
/// Information about a single beam in a beam search
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaBeamView
{
private unsafe llama_token* tokens;
private unsafe LLamaToken* tokens;
private nint n_tokens;

/// <summary>
@@ -27,7 +25,7 @@ public struct LLamaBeamView
/// <summary>
/// Tokens in this beam
/// </summary>
public readonly Span<llama_token> Tokens
public readonly Span<LLamaToken> Tokens
{
get
{
@@ -35,7 +33,7 @@ public struct LLamaBeamView
{
if (n_tokens > int.MaxValue)
throw new InvalidOperationException("More than 2147483647 tokens is not supported");
return new Span<llama_token>(tokens, (int)n_tokens);
return new Span<LLamaToken>(tokens, (int)n_tokens);
}
}
}

LLama/Native/LLamaNativeBatch.cs (+2, -5)

@@ -1,10 +1,7 @@
using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices;

namespace LLama.Native;

using llama_token = Int32;

/// <summary>
/// Input data for llama_decode
/// A llama_batch object can contain input about one or many sequences
@@ -21,7 +18,7 @@ public unsafe struct LLamaNativeBatch
/// <summary>
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
/// </summary>
public llama_token* token;
public LLamaToken* token;

/// <summary>
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created


LLama/Native/LLamaToken.cs (+38, -0)

@@ -0,0 +1,38 @@
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// A single token
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly record struct LLamaToken
{
/// <summary>
/// The raw value
/// </summary>
private readonly int Value;

/// <summary>
/// Create a new LLamaToken
/// </summary>
/// <param name="value"></param>
private LLamaToken(int value)
{
Value = value;
}

/// <summary>
/// Convert a LLamaToken into an integer (extract the raw value)
/// </summary>
/// <param name="pos"></param>
/// <returns></returns>
public static explicit operator int(LLamaToken pos) => pos.Value;

/// <summary>
/// Convert an integer into a LLamaToken
/// </summary>
/// <param name="value"></param>
/// <returns></returns>
public static implicit operator LLamaToken(int value) => new(value);
}
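
The struct above defines an implicit conversion from int (so existing raw token ids keep working) and an explicit conversion back to int (so callers must be deliberate about extracting the raw value, e.g. when indexing a logits array). A minimal sketch of both directions; the token id 450 and the 32000-entry logits array are illustrative values only:

    using System;
    using LLama.Native;

    static class TokenConversionSketch
    {
        static void Main()
        {
            // Implicit: a raw llama.cpp token id becomes a typed LLamaToken.
            LLamaToken token = 450;

            // Explicit: extract the raw value, e.g. to index into a logits array.
            var logits = new float[32000]; // hypothetical vocabulary size
            logits[(int)token] += 1.0f;

            Console.WriteLine((int)token); // prints 450
        }
    }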

LLama/Native/LLamaTokenData.cs (+2, -2)

@@ -11,7 +11,7 @@ public struct LLamaTokenData
/// <summary>
/// token id
/// </summary>
public int id;
public LLamaToken id;

/// <summary>
/// log-odds of the token
@@ -29,7 +29,7 @@ public struct LLamaTokenData
/// <param name="id"></param>
/// <param name="logit"></param>
/// <param name="p"></param>
public LLamaTokenData(int id, float logit, float p)
public LLamaTokenData(LLamaToken id, float logit, float p)
{
this.id = id;
this.logit = logit;


LLama/Native/LLamaTokenDataArray.cs (+8, -10)

@@ -2,8 +2,6 @@
using System.Buffers;
using System.Runtime.InteropServices;

using llama_token = System.Int32;

namespace LLama.Native
{
/// <summary>
@@ -41,7 +39,7 @@ namespace LLama.Native
{
var candidates = new LLamaTokenData[logits.Length];
for (var token_id = 0; token_id < logits.Length; token_id++)
candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
candidates[token_id] = new LLamaTokenData((LLamaToken)token_id, logits[token_id], 0.0f);

return new LLamaTokenDataArray(candidates);
}
@@ -50,7 +48,7 @@ namespace LLama.Native
/// Overwrite the logit values for all given tokens
/// </summary>
/// <param name="values">tuples of token and logit value to overwrite</param>
public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values)
{
if (values.Length == 0)
return;
@@ -172,13 +170,13 @@ namespace LLama.Native
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLamaToken> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
fixed (int* last_tokens_handle = last_tokens)
fixed (LLamaToken* last_tokens_handle = last_tokens)
{
NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
@@ -220,7 +218,7 @@ namespace LLama.Native
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleToken(SafeLLamaContextHandle context)
public LLamaToken SampleToken(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
@@ -235,7 +233,7 @@ namespace LLama.Native
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleTokenGreedy(SafeLLamaContextHandle context)
public LLamaToken SampleTokenGreedy(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
@@ -254,7 +252,7 @@ namespace LLama.Native
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
public LLamaToken SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
@@ -272,7 +270,7 @@ namespace LLama.Native
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
public LLamaToken SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
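
With these changes the SampleToken* helpers return a typed LLamaToken instead of a raw int. A minimal sketch of the Create-then-sample flow, not from this PR (the wrapper method is an illustrative assumption):

    using LLama.Native;

    static class GreedySamplingSketch
    {
        // Turn raw logits into a candidate array and pick the highest-probability token.
        static LLamaToken SampleGreedy(SafeLLamaContextHandle ctx, float[] logits)
        {
            var candidates = LLamaTokenDataArray.Create(logits);
            return candidates.SampleTokenGreedy(ctx); // now returns LLamaToken
        }
    }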


LLama/Native/NativeApi.Grammar.cs (+1, -3)

@@ -3,8 +3,6 @@ using System.Runtime.InteropServices;

namespace LLama.Native
{
using llama_token = Int32;

public static partial class NativeApi
{
/// <summary>
@@ -48,6 +46,6 @@ namespace LLama.Native
/// <param name="grammar"></param>
/// <param name="token"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, llama_token token);
public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, LLamaToken token);
}
}

LLama/Native/NativeApi.Sampling.cs (+6, -9)

@@ -1,10 +1,7 @@
using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices;

namespace LLama.Native
{
using llama_token = Int32;

public static partial class NativeApi
{
/// <summary>
@@ -21,7 +18,7 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
ref LLamaTokenDataArrayNative candidates,
llama_token* last_tokens, ulong last_tokens_size,
LLamaToken* last_tokens, ulong last_tokens_size,
float penalty_repeat,
float penalty_freq,
float penalty_present);
@@ -115,7 +112,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu);
public static extern LLamaToken llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu);

/// <summary>
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -127,7 +124,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu);
public static extern LLamaToken llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu);

/// <summary>
/// Selects the token with the highest probability.
@@ -136,7 +133,7 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);
public static extern LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);

/// <summary>
/// Randomly selects a token from the candidates based on their probabilities.
@@ -145,6 +142,6 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);
public static extern LLamaToken llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);
}
}

LLama/Native/NativeApi.cs (+14, -16)

@@ -7,8 +7,6 @@ using System.Text;

namespace LLama.Native
{
using llama_token = Int32;

/// <summary>
/// Callback from llama.cpp with log messages
/// </summary>
@@ -141,7 +139,7 @@ namespace LLama.Native
/// <param name="n_token_count_out"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);
public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);

/// <summary>
/// Save session file
@@ -152,7 +150,7 @@ namespace LLama.Native
/// <param name="n_token_count"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens, ulong n_token_count);
public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);

/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
@@ -166,7 +164,7 @@ namespace LLama.Native
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[Obsolete("use llama_decode() instead")]
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past);

/// <summary>
/// Convert the provided text into tokens.
@@ -181,7 +179,7 @@ namespace LLama.Native
/// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special)
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, LLamaToken[] tokens, int n_max_tokens, bool add_bos, bool special)
{
unsafe
{
@@ -202,7 +200,7 @@ namespace LLama.Native

// Do the actual tokenization
fixed (byte* arrayPtr = array)
fixed (llama_token* tokensPtr = tokens)
fixed (LLamaToken* tokensPtr = tokens)
return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
}
finally
@@ -213,13 +211,13 @@ namespace LLama.Native
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token);
public static extern float llama_token_get_score(SafeLlamaModelHandle model, LLamaToken token);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, llama_token token);
public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, LLamaToken token);

/// <summary>
/// Get the size of the context window for the model for this context
@@ -272,21 +270,21 @@ namespace LLama.Native
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_bos(SafeLlamaModelHandle model);
public static extern LLamaToken llama_token_bos(SafeLlamaModelHandle model);

/// <summary>
/// Get the "End of sentence" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_eos(SafeLlamaModelHandle model);
public static extern LLamaToken llama_token_eos(SafeLlamaModelHandle model);

/// <summary>
/// Get the "new line" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
public static extern LLamaToken llama_token_nl(SafeLlamaModelHandle model);

/// <summary>
/// Returns -1 if unknown, 1 for true or 0 for false.
@@ -477,7 +475,7 @@ namespace LLama.Native
/// <param name="llamaToken"></param>
/// <param name="buffer">buffer to write string into</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token llamaToken, Span<byte> buffer)
public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span<byte> buffer)
{
unsafe
{
@@ -488,7 +486,7 @@ namespace LLama.Native
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, llama_token llamaToken, byte* buffer, int length);
static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length);
}

/// <summary>
@@ -505,7 +503,7 @@ namespace LLama.Native
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, bool add_bos, bool special);

/// <summary>
/// Register a callback to receive llama log messages


LLama/Native/SafeLLamaContextHandle.cs (+8, -8)

@@ -153,19 +153,19 @@ namespace LLama.Native
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
ThrowIfDisposed();

if (string.IsNullOrEmpty(text) && !add_bos)
return Array.Empty<int>();
return Array.Empty<LLamaToken>();

// Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
// possibly be more than this.
var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);

// "Rent" an array to write results into (avoiding an allocation of a large array)
var temporaryArray = ArrayPool<int>.Shared.Rent(count);
var temporaryArray = ArrayPool<LLamaToken>.Shared.Rent(count);
try
{
// Do the actual conversion
@@ -177,14 +177,14 @@ namespace LLama.Native
}

// Copy the results from the rented into an array which is exactly the right size
var result = new int[n];
var result = new LLamaToken[n];
Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);

return result;
}
finally
{
ArrayPool<int>.Shared.Return(temporaryArray);
ArrayPool<LLamaToken>.Shared.Return(temporaryArray);
}
}

@@ -194,7 +194,7 @@ namespace LLama.Native
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(int token, Span<byte> dest)
public int TokenToSpan(LLamaToken token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}
@@ -207,11 +207,11 @@ namespace LLama.Native
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <returns>Returns true on success</returns>
[Obsolete("use llama_decode() instead")]
public bool Eval(ReadOnlySpan<int> tokens, int n_past)
public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
{
unsafe
{
fixed (int* pinned = tokens)
fixed (LLamaToken* pinned = tokens)
{
// the entire `eval` system needs replacing with the new batch system!
var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);


LLama/Native/SafeLLamaGrammarHandle.cs (+1, -1)

@@ -119,7 +119,7 @@ namespace LLama.Native
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
public void AcceptToken(SafeLLamaContextHandle ctx, LLamaToken token)
{
NativeApi.llama_grammar_accept_token(ctx, this, token);
}


LLama/Native/SafeLlamaModelHandle.cs (+9, -11)

@@ -123,12 +123,12 @@ namespace LLama.Native
/// <summary>
/// Convert a single llama token into bytes
/// </summary>
/// <param name="llama_token">Token to decode</param>
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(int llama_token, Span<byte> dest)
public int TokenToSpan(LLamaToken token, Span<byte> dest)
{
var length = NativeApi.llama_token_to_piece(this, llama_token, dest);
var length = NativeApi.llama_token_to_piece(this, token, dest);
return Math.Abs(length);
}

@@ -143,12 +143,10 @@ namespace LLama.Native
/// filled with as many characters as possible, starting from the _last_ token.
/// </returns>
[Obsolete("Use a StreamingTokenDecoder instead")]
internal Span<char> TokensToSpan(IReadOnlyList<int> tokens, Span<char> dest, Encoding encoding)
internal Span<char> TokensToSpan(IReadOnlyList<LLamaToken> tokens, Span<char> dest, Encoding encoding)
{
var decoder = new StreamingTokenDecoder(encoding, this);

foreach (var token in tokens)
decoder.Add(token);
decoder.AddRange(tokens);

var str = decoder.Read();

@@ -172,7 +170,7 @@ namespace LLama.Native
/// <param name="encoding"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
@@ -191,11 +189,11 @@ namespace LLama.Native
fixed (byte* bytesPtr = &bytes[0])
{
// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special);
var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special);

// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new int[count];
fixed (int* tokensPtr = &tokens[0])
var tokens = new LLamaToken[count];
fixed (LLamaToken* tokensPtr = &tokens[0])
{
NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
return tokens;


LLama/Native/SamplingApi.cs (+4, -6)

@@ -4,8 +4,6 @@

namespace LLama.Native
{
using llama_token = Int32;

/// <summary>
/// Direct translation of the llama.cpp sampling API
/// </summary>
@@ -110,7 +108,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
public static LLamaToken llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);
}
@@ -125,7 +123,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
public static LLamaToken llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
}
@@ -137,7 +135,7 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
public static LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
return candidates.SampleTokenGreedy(ctx);
}
@@ -149,7 +147,7 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleToken() method")]
public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
public static LLamaToken llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
return candidates.SampleToken(ctx);
}


LLama/Sampling/BaseSamplingPipeline.cs (+10, -10)

@@ -12,21 +12,21 @@ public abstract class BaseSamplingPipeline
: ISamplingPipeline
{
private int _savedLogitsCount;
private (int index, float logit)[]? _savedLogits;
private (LLamaToken index, float logit)[]? _savedLogits;

/// <inheritdoc/>
public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
var protectedLogits = GetProtectedTokens(ctx);
_savedLogitsCount = protectedLogits.Count;
_savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
_savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount);
try
{
// Save the values of protected logits
for (var i = 0; i < protectedLogits.Count; i++)
{
var index = protectedLogits[i];
var value = logits[index];
var value = logits[(int)index];
_savedLogits[i] = (index, value);
}

@@ -47,7 +47,7 @@ public abstract class BaseSamplingPipeline
}
finally
{
ArrayPool<(int, float)>.Shared.Return(_savedLogits);
ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits);
_savedLogits = null;
_savedLogitsCount = 0;
}
@@ -58,7 +58,7 @@ public abstract class BaseSamplingPipeline
/// Get all of the "protected" tokens that cannot be changed by ProcessLogits
/// </summary>
/// <returns></returns>
protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);
protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx);

/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
@@ -74,7 +74,7 @@ public abstract class BaseSamplingPipeline

// Restore the values of protected logits
for (var i = 0; i < saved.Length; i++)
logits[saved[i].index] = saved[i].logit;
logits[(int)saved[i].index] = saved[i].logit;
}

/// <summary>
@@ -96,7 +96,7 @@ public abstract class BaseSamplingPipeline
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

/// <summary>
/// Process the LLamaTokenDataArray and select a single token
@@ -105,7 +105,7 @@ public abstract class BaseSamplingPipeline
/// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);

/// <summary>
/// Choose the final token from the candidates
@@ -113,7 +113,7 @@ public abstract class BaseSamplingPipeline
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <returns></returns>
protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

/// <inheritdoc/>
public virtual void Reset()


LLama/Sampling/DefaultSamplingPipeline.cs (+6, -6)

@@ -99,27 +99,27 @@ public sealed class DefaultSamplingPipeline
/// </summary>
public bool PenalizeNewline { get; set; } = false;

private readonly int[] _newlineToken = new int[1];
private readonly LLamaToken[] _newlineToken = new LLamaToken[1];

/// <inheritdoc />
protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
{
if (PenalizeNewline)
return Array.Empty<int>();
return Array.Empty<LLamaToken>();

_newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
return _newlineToken;
}

/// <inheritdoc />
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
foreach (var (key, value) in LogitBias)
logits[key] += value;
}

/// <inheritdoc />
protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
{
// Apply penalties to candidates
candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
@@ -142,7 +142,7 @@ public sealed class DefaultSamplingPipeline
}

/// <inheritdoc />
protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
protected override LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
return candidates.SampleToken(ctx);
}

LLama/Sampling/ISamplingPipeline.cs (+4, -4)

@@ -19,7 +19,7 @@ public interface ISamplingPipeline
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A span of tokens recently returned by the model</param>
/// <returns></returns>
int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

/// <summary>
/// Reset all internal state of the sampling pipeline
@@ -40,13 +40,13 @@ public static class ISamplingPipelineExtensions
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(lastTokens);
return pipeline.Sample(ctx, logits, span);
#else
var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
var copy = ArrayPool<LLamaToken>.Shared.Rent(lastTokens.Count);
try
{
lastTokens.CopyTo(copy);
@@ -54,7 +54,7 @@ public static class ISamplingPipelineExtensions
}
finally
{
ArrayPool<int>.Shared.Return(copy);
ArrayPool<LLamaToken>.Shared.Return(copy);
}
#endif
}

LLama/StreamingTokenDecoder.cs (+21, -2)

@@ -69,7 +69,7 @@ namespace LLama
/// Add a single token to the decoder
/// </summary>
/// <param name="token"></param>
public void Add(int token)
public void Add(LLamaToken token)
{
var charsArr = ArrayPool<char>.Shared.Rent(16);
var bytesArr = ArrayPool<byte>.Shared.Rent(16);
@@ -108,7 +108,7 @@ namespace LLama

// Converts a single token into bytes, using the `bytes` array as temporary storage.
// If the `bytes` array is too small it will get a larger one from the ArrayPool.
static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model)
{
// Try to get bytes
var l = model.TokenToSpan(token, bytes);
@@ -129,6 +129,15 @@ namespace LLama
}
}

/// <summary>
/// Add a single token to the decoder
/// </summary>
/// <param name="token"></param>
public void Add(int token)
{
Add((LLamaToken)token);
}

/// <summary>
/// Add all tokens in the given enumerable
/// </summary>
@@ -139,6 +148,16 @@ namespace LLama
Add(item);
}

/// <summary>
/// Add all tokens in the given enumerable
/// </summary>
/// <param name="tokens"></param>
public void AddRange(IEnumerable<LLamaToken> tokens)
{
foreach (var item in tokens)
Add((int)item);
}

/// <summary>
/// Read all decoded characters and clear the buffer
/// </summary>
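
The new Add(LLamaToken) and AddRange(IEnumerable&lt;LLamaToken&gt;) overloads let typed tokens be fed straight into the decoder. A minimal sketch of decoding a whole stream with one decoder instance (the wrapper method is an illustrative assumption):

    using System.Text;
    using LLama;
    using LLama.Native;

    static class DecoderSketch
    {
        // Decode a sequence of LLamaToken values into text, keeping one decoder
        // alive for the whole stream and reading the result at the end.
        static string Decode(SafeLlamaModelHandle model, LLamaToken[] tokens)
        {
            var decoder = new StreamingTokenDecoder(Encoding.UTF8, model);
            decoder.AddRange(tokens);
            return decoder.Read();
        }
    }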

