
Rewritten sampling API to be accessed through the `LLamaTokenDataArray` object

tags/v0.7.0^2
Martin Evans, 2 years ago
Parent commit: e81b3023d5
7 changed files with 259 additions and 165 deletions
  1. LLama.Examples/NewVersion/BatchedDecoding.cs (+8, -5)
  2. LLama.Examples/NewVersion/TestRunner.cs (+1, -1)
  3. LLama/LLamaContext.cs (+15, -20)
  4. LLama/Native/LLamaTokenDataArray.cs (+193, -43)
  5. LLama/Native/NativeApi.Sampling.cs (+9, -13)
  6. LLama/Native/SafeLLamaGrammarHandle.cs (+10, -0)
  7. LLama/Native/SamplingApi.cs (+23, -83)
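
The change replaces the static `SamplingApi.llama_sample_*` helpers with instance methods on `LLamaTokenDataArray`. A minimal sketch of the new call style, assuming a valid context handle and a logits span for the position being sampled (the top-k, top-p and temperature values below are illustrative, not library defaults):

using System;
using LLama.Native;

static class SamplingSketch
{
    // Hedged example: wrap raw logits, then chain the new instance methods.
    public static int SampleNext(SafeLLamaContextHandle ctx, Span<float> logits)
    {
        // Copy the logits into a managed candidates array.
        var candidates = LLamaTokenDataArray.Create(logits);

        // Previously: SamplingApi.llama_sample_top_k(ctx, candidates, 40, 1); etc.
        candidates.TopK(ctx, 40);
        candidates.TopP(ctx, 0.95f);
        candidates.Temperature(ctx, 0.8f);

        // Select a token from the remaining candidates.
        return candidates.SampleToken(ctx);
    }
}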

LLama.Examples/NewVersion/BatchedDecoding.cs (+8, -5)

@@ -6,6 +6,10 @@ using LLama.Native;

namespace LLama.Examples.NewVersion;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
private const int n_parallel = 8;
@@ -116,12 +120,11 @@ public class BatchedDecoding
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}
using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native);

candidates_native.TopK(context.NativeHandle, top_k);
candidates_native.TopP(context.NativeHandle, top_p);
candidates_native.Temperature(context.NativeHandle, temp);
var new_token_id = candidates_native.SampleToken(context.NativeHandle);
candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
candidates.Temperature(context.NativeHandle, temp);
var new_token_id = candidates.SampleToken(context.NativeHandle);

if (new_token_id == eos || new_token_id == nl)
{


LLama.Examples/NewVersion/TestRunner.cs (+1, -1)

@@ -22,7 +22,7 @@
Console.WriteLine("12: Semantic Kernel Chat.");
Console.WriteLine("13: Semantic Kernel Memory.");
Console.WriteLine("14: Coding Assistant.");
Console.WriteLine("15: Batch Decoding Benchmark.");
Console.WriteLine("15: Batch Decoding.");

while (true)
{


LLama/LLamaContext.cs (+15, -20)

@@ -235,13 +235,13 @@ namespace LLama

if (grammar != null)
{
SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
candidates.ApplyGrammar(NativeHandle, grammar);
}

if (temperature <= 0)
{
// Greedy sampling
id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
id = candidates.SampleTokenGreedy(NativeHandle);
}
else
{
@@ -250,32 +250,28 @@ namespace LLama
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token(NativeHandle, candidates);
candidates.TopK(NativeHandle, topK);
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
}
}
mirostat_mu = mu;
}

if (grammar != null)
{
NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
}
grammar?.AcceptToken(NativeHandle, id);

return id;
}
@@ -305,7 +301,7 @@ namespace LLama
}

// Save the newline logit value
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];

// Convert logits into token candidates
@@ -316,8 +312,7 @@ namespace LLama
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();

// Apply penalties to candidates
SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);

// Restore newline token logit value if necessary
if (!penalizeNL)


LLama/Native/LLamaTokenDataArray.cs (+193, -43)

@@ -45,6 +45,199 @@ namespace LLama.Native

return new LLamaTokenDataArray(candidates);
}

#region sampling
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx"></param>
/// <param name="grammar"></param>
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
sorted = st.sorted;
}
}

/// <summary>
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
/// <param name="context"></param>
/// <param name="k">Number of tokens to keep</param>
/// <param name="minKeep">Minimum number to keep</param>
public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
sorted = st.sorted;
}
}

/// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="minKeep"></param>
public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
sorted = st.sorted;
}
}

/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
/// <param name="context"></param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
sorted = st.sorted;
}
}

/// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_typical(context, ref st, p, min_keep);
sorted = st.sorted;
}
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="context"></param>
/// <param name="last_tokens"></param>
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
using (var last_tokens_handle = last_tokens.Pin())
{
NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
}
}
}

/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
/// </summary>
/// <param name="context"></param>
/// <param name="temp"></param>
public void Temperature(SafeLLamaContextHandle context, float temp)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_temperature(context, ref st, temp);
sorted = st.sorted;
}
}

/// <summary>
/// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
/// </summary>
/// <param name="context"></param>
public void Softmax(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_softmax(context, ref st);
sorted = st.sorted;
}
}

/// <summary>
/// Randomly selects a token from the candidates based on their probabilities.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleToken(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token(context, ref st);
sorted = st.sorted;
return token;
}
}

/// <summary>
/// Selects the token with the highest probability.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleTokenGreedy(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_greedy(context, ref st);
sorted = st.sorted;
return token;
}
}

/// <summary>
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// </summary>
/// <param name="context"></param>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
sorted = st.sorted;
return token;
}
}

/// <summary>
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// </summary>
/// <param name="context"></param>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
sorted = st.sorted;
return token;
}
}
#endregion
}

/// <summary>
@@ -96,48 +289,5 @@ namespace LLama.Native

return handle;
}

/// <summary>
/// Perform TopK sampling, sorting the data and reducing the size to k
/// </summary>
/// <param name="context"></param>
/// <param name="k">Number of tokens to keep</param>
/// <param name="minKeep">Minimum number to keep</param>
public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
{
NativeApi.llama_sample_top_k(context, ref this, k, minKeep);
}

/// <summary>
/// Perform top p sampling, sorting the data and keeping only logits more likely than p
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="minKeep"></param>
public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
NativeApi.llama_sample_top_p(context, ref this, p, minKeep);
}

/// <summary>
/// Apply temperature to logits
/// </summary>
/// <param name="context"></param>
/// <param name="temp"></param>
public void Temperature(SafeLLamaContextHandle context, float temp)
{
NativeApi.llama_sample_temperature(context, ref this, temp);
}

/// <summary>
/// Sample a token from the set of possible tokens
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public int SampleToken(SafeLLamaContextHandle context)
{
return NativeApi.llama_sample_token(context, ref this);
}
}
}
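
The Mirostat samplers carry their `mu` state by `ref` between calls, matching how `LLamaContext` stores `mirostat_mu`. A hedged usage sketch (the tau, eta and temperature values are illustrative):

using LLama.Native;

static class MirostatSketch
{
    // Hedged example of Mirostat 2.0 sampling with the new instance methods.
    // The caller keeps `mu` alive between steps; it would typically start at 2 * tau.
    public static int SampleMirostat2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ref float mu)
    {
        candidates.Temperature(ctx, 0.8f);
        return candidates.SampleTokenMirostat2(ctx, 5.0f, 0.1f, ref mu); // tau = 5.0, eta = 0.1
    }
}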

LLama/Native/NativeApi.Sampling.cs (+9, -13)

@@ -9,26 +9,22 @@ namespace LLama.Native
{
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
/// <param name="penalty_repeat">Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.</param>
/// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
/// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
ref LLamaTokenDataArrayNative candidates,
llama_token* last_tokens, ulong last_tokens_size,
float penalty_repeat,
float penalty_freq,
float penalty_present);

/// <summary>
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
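
The two legacy entry points (`llama_sample_repetition_penalty` and `llama_sample_frequency_and_presence_penalties`) are merged into the single `llama_sample_repetition_penalties`, surfaced through `LLamaTokenDataArray.RepetitionPenalty`. A hedged sketch of the managed call (the window size and penalty values are illustrative):

using System;
using LLama.Native;

static class PenaltySketch
{
    // Hedged example: repeat, frequency and presence penalties applied in one call.
    public static void ApplyPenalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int[] recentTokens)
    {
        Memory<int> lastTokens = recentTokens; // llama_token is an int alias
        candidates.RepetitionPenalty(ctx, lastTokens, 1.1f, 0.0f, 0.0f);
    }
}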


LLama/Native/SafeLLamaGrammarHandle.cs (+10, -0)

@@ -102,5 +102,15 @@ namespace LLama.Native
return new(grammar_ptr);
}
#endregion

/// <summary>
/// Accepts the sampled token into the grammar
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
NativeApi.llama_grammar_accept_token(ctx, this, token);
}
}
}
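
Combined with `LLamaTokenDataArray.ApplyGrammar`, the new `AcceptToken` helper keeps the grammar state in step with sampling. A hedged sketch of one grammar-constrained step (the greedy selection is illustrative):

using LLama.Native;

static class GrammarSketch
{
    // Hedged example: mask candidates with the grammar, sample, then advance the grammar.
    public static int SampleWithGrammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
    {
        candidates.ApplyGrammar(ctx, grammar);   // suppress tokens the grammar forbids
        var token = candidates.SampleTokenGreedy(ctx);
        grammar.AcceptToken(ctx, token);         // replaces NativeApi.llama_grammar_accept_token(...)
        return token;
    }
}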

LLama/Native/SamplingApi.cs (+23, -83)

@@ -9,7 +9,7 @@ namespace LLama.Native
/// <summary>
/// Direct translation of the llama.cpp sampling API
/// </summary>
public unsafe class SamplingApi
public class SamplingApi
{
/// <summary>
/// Apply grammar rules to candidate tokens
@@ -17,70 +17,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="grammar"></param>
[Obsolete("use LLamaTokenDataArray ApplyGrammar method")]
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
{
llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
using var last_tokens_handle = last_tokens.Pin();

NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
}

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{
llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
}

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
using var last_tokens_handle = last_tokens.Pin();

NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
candidates.ApplyGrammar(ctx, grammar);
}

/// <summary>
@@ -88,10 +28,10 @@ namespace LLama.Native
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
[Obsolete("use LLamaTokenDataArray Softmax method")]
public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_softmax(ctx, ref st);
candidates.Softmax(ctx);
}

/// <summary>
@@ -101,10 +41,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="k"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TopK method")]
public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
candidates.TopK(ctx, k, min_keep);
}

/// <summary>
@@ -114,10 +54,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TopP method")]
public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
candidates.TopP(ctx, p, min_keep);
}

/// <summary>
@@ -127,10 +67,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TailFree method")]
public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
candidates.TailFree(ctx, z, min_keep);
}

/// <summary>
@@ -140,10 +80,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray LocallyTypical method")]
public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
candidates.LocallyTypical(ctx, p, min_keep);
}

/// <summary>
@@ -153,10 +93,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="temp"></param>
[Obsolete("use LLamaTokenDataArray Temperature() method")]
public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_temperature(ctx, ref st, temp);
candidates.Temperature(ctx, temp);
}

/// <summary>
@@ -169,10 +109,10 @@ namespace LLama.Native
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);
}

/// <summary>
@@ -184,10 +124,10 @@ namespace LLama.Native
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
}

/// <summary>
@@ -196,10 +136,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_greedy(ctx, ref st);
return candidates.SampleTokenGreedy(ctx);
}

/// <summary>
@@ -208,10 +148,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleToken() method")]
public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token(ctx, ref st);
return candidates.SampleToken(ctx);
}
}
}
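
The old static helpers remain as thin `[Obsolete]` forwarders, so existing callers keep compiling (with a warning) while new code moves to the instance methods. A hedged before/after sketch:

using LLama.Native;

static class MigrationSketch
{
    public static int Sample(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
    {
        // Old style (still compiles, now marked [Obsolete] and forwarding internally):
        //   SamplingApi.llama_sample_temperature(ctx, candidates, 0.8f);
        //   return SamplingApi.llama_sample_token(ctx, candidates);

        // New style:
        candidates.Temperature(ctx, 0.8f);
        return candidates.SampleToken(ctx);
    }
}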
