Browse Source

- Most importantly: Fixed issue in `SamplingApi`, `Memory` was pinned, but never unpinned!

- Moved repeated code to convert `LLamaTokenDataArray` into a `LLamaTokenDataArrayNative` into a helper method.
   - Modified all call sites to dispose the `MemoryHandle`
 - Saved one copy of the `List<LLamaTokenData>` into a `LLamaTokenData[]` in `LlamaModel`
tags/v0.4.2-preview
Martin Evans 2 years ago
parent
commit
ec49bdd6eb
5 changed files with 108 additions and 118 deletions
  1. +3
    -7
      LLama/LLamaModel.cs
  2. +59
    -11
      LLama/Native/LLamaTokenDataArray.cs
  3. +19
    -14
      LLama/Native/NativeApi.Sampling.cs
  4. +24
    -79
      LLama/Native/SamplingApi.cs
  5. +3
    -7
      LLama/OldVersion/LLamaModel.cs

+ 3
- 7
LLama/LLamaModel.cs View File

@@ -294,14 +294,10 @@ namespace LLama
} }
} }


var candidates = new List<LLamaTokenData>();
candidates.Capacity = n_vocab;
var candidates = new LLamaTokenData[n_vocab];
for (llama_token token_id = 0; token_id < n_vocab; token_id++) for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f));
}

LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false);
candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);


// Apply penalties // Apply penalties
float nl_logit = logits[NativeApi.llama_token_nl()]; float nl_logit = logits[NativeApi.llama_token_nl()];


+ 59
- 11
LLama/Native/LLamaTokenDataArray.cs View File

@@ -1,32 +1,80 @@
using System; using System;
using System.Buffers; using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text;


namespace LLama.Native namespace LLama.Native
{ {
[StructLayout(LayoutKind.Sequential)]
/// <summary>
/// Contains an array of LLamaTokenData, potentially sorted.
/// </summary>
public struct LLamaTokenDataArray public struct LLamaTokenDataArray
{ {
public Memory<LLamaTokenData> data;
public ulong size;
[MarshalAs(UnmanagedType.I1)]
public bool sorted;
/// <summary>
/// The LLamaTokenData
/// </summary>
public readonly Memory<LLamaTokenData> data;


public LLamaTokenDataArray(LLamaTokenData[] data, ulong size, bool sorted)
/// <summary>
/// Indicates if `data` is sorted
/// </summary>
public readonly bool sorted;

/// <summary>
/// Create a new LLamaTokenDataArray
/// </summary>
/// <param name="tokens"></param>
/// <param name="isSorted"></param>
public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
{ {
this.data = data;
this.size = size;
this.sorted = sorted;
data = tokens;
sorted = isSorted;
} }
} }


/// <summary>
/// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
/// </summary>
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public struct LLamaTokenDataArrayNative public struct LLamaTokenDataArrayNative
{ {
/// <summary>
/// A pointer to an array of LlamaTokenData
/// </summary>
/// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use</remarks>
public IntPtr data; public IntPtr data;

/// <summary>
/// Number of LLamaTokenData in the array
/// </summary>
public ulong size; public ulong size;

/// <summary>
/// Indicates if the items in the array are sorted
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool sorted; public bool sorted;

/// <summary>
/// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
/// </summary>
/// <param name="array">Data source</param>
/// <param name="native">Created native array</param>
/// <returns>A memory handle, pinning the data in place until disposed</returns>
public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native)
{
var handle = array.data.Pin();

unsafe
{
native = new LLamaTokenDataArrayNative
{
data = new IntPtr(handle.Pointer),
size = (ulong)array.data.Length,
sorted = array.sorted
};
}

return handle;
}
} }
} }

+ 19
- 14
LLama/Native/NativeApi.Sampling.cs View File

@@ -1,11 +1,10 @@
using System; using System;
using System.Collections.Generic;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text;


namespace LLama.Native namespace LLama.Native
{ {
using llama_token = Int32; using llama_token = Int32;

public unsafe partial class NativeApi public unsafe partial class NativeApi
{ {
/// <summary> /// <summary>
@@ -17,7 +16,7 @@ namespace LLama.Native
/// <param name="last_tokens_size"></param> /// <param name="last_tokens_size"></param>
/// <param name="penalty"></param> /// <param name="penalty"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty);
public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty);


/// <summary> /// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
@@ -29,7 +28,7 @@ namespace LLama.Native
/// <param name="alpha_frequency"></param> /// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param> /// <param name="alpha_presence"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);


/// <summary> /// <summary>
/// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
@@ -37,7 +36,7 @@ namespace LLama.Native
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param> /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, IntPtr candidates);
public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);


/// <summary> /// <summary>
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
@@ -47,7 +46,7 @@ namespace LLama.Native
/// <param name="k"></param> /// <param name="k"></param>
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, IntPtr candidates, int k, ulong min_keep);
public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, int k, ulong min_keep);


/// <summary> /// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
@@ -57,7 +56,7 @@ namespace LLama.Native
/// <param name="p"></param> /// <param name="p"></param>
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep);
public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);


/// <summary> /// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
@@ -67,7 +66,7 @@ namespace LLama.Native
/// <param name="z"></param> /// <param name="z"></param>
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, IntPtr candidates, float z, ulong min_keep);
public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float z, ulong min_keep);


/// <summary> /// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
@@ -77,10 +76,16 @@ namespace LLama.Native
/// <param name="p"></param> /// <param name="p"></param>
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep);
public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);


/// <summary>
/// Modify logits by temperature
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="temp"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_temperature(SafeLLamaContextHandle ctx, IntPtr candidates, float temp);
public static extern void llama_sample_temperature(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float temp);


/// <summary> /// <summary>
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -93,7 +98,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float* mu);
public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, float* mu);


/// <summary> /// <summary>
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -105,7 +110,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float* mu);
public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, float* mu);


/// <summary> /// <summary>
/// Selects the token with the highest probability. /// Selects the token with the highest probability.
@@ -114,7 +119,7 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param> /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, IntPtr candidates);
public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);


/// <summary> /// <summary>
/// Randomly selects a token from the candidates based on their probabilities. /// Randomly selects a token from the candidates based on their probabilities.
@@ -123,6 +128,6 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param> /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, IntPtr candidates);
public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);
} }
} }

+ 24
- 79
LLama/Native/SamplingApi.cs View File

@@ -1,7 +1,4 @@
using System; using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;


namespace LLama.Native namespace LLama.Native
{ {
@@ -18,12 +15,8 @@ namespace LLama.Native
/// <param name="penalty"></param> /// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty) public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_repetition_penalty(ctx, new IntPtr(&st), last_tokens, last_tokens_size, penalty);
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty);
} }


/// <summary> /// <summary>
@@ -37,12 +30,8 @@ namespace LLama.Native
/// <param name="alpha_presence"></param> /// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_frequency_and_presence_penalties(ctx, new IntPtr(&st), last_tokens, last_tokens_size, alpha_frequency, alpha_presence);
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence);
} }


/// <summary> /// <summary>
@@ -52,12 +41,8 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param> /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_softmax(ctx, new IntPtr(&st));
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_softmax(ctx, ref st);
} }


/// <summary> /// <summary>
@@ -69,12 +54,8 @@ namespace LLama.Native
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep) public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_top_k(ctx, new IntPtr(&st), k, min_keep);
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
} }


/// <summary> /// <summary>
@@ -86,12 +67,8 @@ namespace LLama.Native
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_top_p(ctx, new IntPtr(&st), p, min_keep);
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
} }


/// <summary> /// <summary>
@@ -103,12 +80,8 @@ namespace LLama.Native
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep) public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_tail_free(ctx, new IntPtr(&st), z, min_keep);
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
} }


/// <summary> /// <summary>
@@ -120,22 +93,14 @@ namespace LLama.Native
/// <param name="min_keep"></param> /// <param name="min_keep"></param>
public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_typical(ctx, new IntPtr(&st), p, min_keep);
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
} }


public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
NativeApi.llama_sample_temperature(ctx, new IntPtr(&st), temp);
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_temperature(ctx, ref st, temp);
} }


/// <summary> /// <summary>
@@ -150,17 +115,11 @@ namespace LLama.Native
/// <returns></returns> /// <returns></returns>
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
llama_token res;
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
fixed(float* pmu = &mu) fixed(float* pmu = &mu)
{ {
res = NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, pmu);
return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu);
} }
return res;
} }


/// <summary> /// <summary>
@@ -174,17 +133,11 @@ namespace LLama.Native
/// <returns></returns> /// <returns></returns>
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
llama_token res;
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
fixed (float* pmu = &mu) fixed (float* pmu = &mu)
{ {
res = NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, pmu);
return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu);
} }
return res;
} }


/// <summary> /// <summary>
@@ -195,12 +148,8 @@ namespace LLama.Native
/// <returns></returns> /// <returns></returns>
public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
return NativeApi.llama_sample_token_greedy(ctx, new IntPtr(&st));
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_greedy(ctx, ref st);
} }


/// <summary> /// <summary>
@@ -211,12 +160,8 @@ namespace LLama.Native
/// <returns></returns> /// <returns></returns>
public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{ {
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
return NativeApi.llama_sample_token(ctx, new IntPtr(&st));
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token(ctx, ref st);
} }
} }
} }

+ 3
- 7
LLama/OldVersion/LLamaModel.cs View File

@@ -632,14 +632,10 @@ namespace LLama.OldVersion
logits[key] += value; logits[key] += value;
} }


var candidates = new List<LLamaTokenData>();
candidates.Capacity = n_vocab;
var candidates = new LLamaTokenData[n_vocab];
for (llama_token token_id = 0; token_id < n_vocab; token_id++) for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f));
}

LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false);
candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);


// Apply penalties // Apply penalties
float nl_logit = logits[NativeApi.llama_token_nl()]; float nl_logit = logits[NativeApi.llama_token_nl()];


Loading…
Cancel
Save