
It works!

Had to update the binary to `b1426`.
tags/v0.7.0^2
Martin Evans, 2 years ago
commit a024d2242e
10 changed files with 246 additions and 39 deletions
  1. +1   -0    LLama.Examples/LLama.Examples.csproj
  2. +0   -25   LLama.Examples/NewVersion/BatchedBench.cs
  3. +219 -0    LLama.Examples/NewVersion/BatchedDecoding.cs
  4. +1   -1    LLama.Examples/NewVersion/TestRunner.cs
  5. +1   -1    LLama/LLamaContext.cs
  6. +1   -1    LLama/LLamaInstructExecutor.cs
  7. +3   -3    LLama/LLamaInteractExecutor.cs
  8. +1   -1    LLama/Native/LLamaBatchSafeHandle.cs
  9. +15  -3    LLama/Native/LLamaNativeBatch.cs
  10. +4  -4    LLama/Native/NativeApi.cs

+ 1 - 0  LLama.Examples/LLama.Examples.csproj

@@ -8,6 +8,7 @@
     <Platforms>AnyCPU;x64</Platforms>
     <!-- Set IncludeBuiltInRuntimes to false to include your own runtime libraries and not link the defaults -->
     <IncludeBuiltInRuntimes>true</IncludeBuiltInRuntimes>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>

   <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
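Note: enabling `AllowUnsafeBlocks` is presumably what allows the new `BatchedDecoding` example below to compile, since it reads and writes native pointers (`logits`, `seq_id`, …) inside `unsafe` blocks.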


+ 0 - 25  LLama.Examples/NewVersion/BatchedBench.cs (deleted)

@@ -1,25 +0,0 @@
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.NewVersion;

public class BatchedBench
{
    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        //todo:var modelPath = Console.ReadLine();
        var modelPath = @"C:\Users\Martin\Documents\Python\oobabooga_windows\text-generation-webui\models\llama-2-7b-chat.Q5_K_M.gguf";

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        parameters.ContextSize = (uint)model.ContextSize;
        using var context = model.CreateContext(parameters);

        var n_kv_max = 1024;

        using var batch = LLamaBatchSafeHandle.Create(n_kv_max, 0, 1);

    }
}

+ 219 - 0  LLama.Examples/NewVersion/BatchedDecoding.cs (new file)

@@ -0,0 +1,219 @@
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.NewVersion;

public class BatchedDecoding
{
    private const int n_parallel = 8;
    private const int n_len = 32;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        //todo:var modelPath = Console.ReadLine();
        var modelPath = @"C:\Users\Martin\Documents\Python\oobabooga_windows\text-generation-webui\models\llama-2-7b-chat.Q5_K_M.gguf";

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "I would like to tell you about";

        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
        parameters.Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MinValue, int.MaxValue));
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);

        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
            await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
            return;
        }

        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);

        // evaluate the initial prompt
        for (var i = 0; i < prompt_tokens.Length; i++)
            llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false);
        Debug.Assert(batch.NativeBatch.n_tokens == (int)prompt_tokens.Length);

        // llama_decode will output logits only for the last token of the prompt
        unsafe
        {
            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
        }

        if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }

        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.NativeBatch.n_tokens - 1);

        int n_cur = batch.NativeBatch.n_tokens;
        int n_decode = 0;

        var streams = new List<int>[n_parallel];
        for (var i = 0; i < n_parallel; i++)
            streams[i] = new();

        var eos = NativeApi.llama_token_eos(model.NativeHandle);
        var nl = NativeApi.llama_token_nl(model.NativeHandle);

        var timer = new Stopwatch();
        timer.Start();
        while (n_cur <= n_len)
        {
            llama_batch_clear(batch);

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                    continue;

                unsafe
                {
                    var n_vocab = model.VocabCount;
                    var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]);

                    var candidates = new LLamaTokenData[n_vocab];
                    for (var token_id = 0; token_id < n_vocab; token_id++)
                    {
                        candidates[token_id] = new LLamaTokenData
                        {
                            id = token_id,
                            logit = logits[token_id]
                        };
                    }

                    var candidates_p = new LLamaTokenDataArray(candidates);
                    using var pin = LLamaTokenDataArrayNative.Create(candidates_p, out var candidates_native);

                    const int top_k = 40;
                    const float top_p = 0.9f;
                    const float temp = 0.4f;

                    NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
                    NativeApi.llama_sample_top_p(context.NativeHandle, ref candidates_native, top_p, 1);
                    NativeApi.llama_sample_temperature(context.NativeHandle, ref candidates_native, temp);

                    var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);

                    if (new_token_id == eos || new_token_id == nl)
                    {
                        i_batch[i] = -1;
                        Console.WriteLine($"Completed Stream {i} early");
                        continue;
                    }

                    streams[i].Add(new_token_id);

                    i_batch[i] = batch.NativeBatch.n_tokens;

                    // push this new token for next evaluation
                    llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);

                    n_decode++;
                }
            }

            // all streams are finished
            if (batch.NativeBatch.n_tokens == 0)
            {
                break;
            }

            n_cur++;

            // evaluate the current batch with the transformer model
            if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }

        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

        var index = 0;
        foreach (var stream in streams)
        {
            var text = context.DeTokenize(stream);

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }
    }

    /// <summary>
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
    /// </summary>
    private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
    {
        unsafe
        {
            ref var batch = ref batchHandle.NativeBatch;

            batch.token[batch.n_tokens] = token;
            batch.pos[batch.n_tokens] = pos;
            batch.n_seq_id[batch.n_tokens] = sequences.Count;

            for (var i = 0; i < sequences.Count; i++)
                batch.seq_id[batch.n_tokens][i] = sequences[i];

            batch.logits[batch.n_tokens] = Convert.ToByte(logits);

            batch.n_tokens++;
        }
    }

    /// <summary>
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
    /// </summary>
    /// <param name="batchHandle"></param>
    private static void llama_batch_clear(LLamaBatchSafeHandle batchHandle)
    {
        batchHandle.NativeBatch.n_tokens = 0;
    }
}
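In short, the new example mirrors llama.cpp's `batched` example: the prompt is decoded once for sequence 0, `llama_kv_cache_seq_cp` shares that prompt's KV cache with the other `n_parallel` sequences, and each loop iteration samples one token per still-active stream and decodes them all together in a single batch. It is wired into the example menu as option 15 (see the `TestRunner.cs` change below).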

+ 1 - 1  LLama.Examples/NewVersion/TestRunner.cs

@@ -91,7 +91,7 @@
            }
            else if (choice == 15)
            {
-               await BatchedBench.Run();
+               await BatchedDecoding.Run();
            }
            else
            {


+ 1 - 1  LLama/LLamaContext.cs

@@ -305,7 +305,7 @@ namespace LLama
            }

            // Save the newline logit value
-           var nl_token = NativeApi.llama_token_nl(NativeHandle);
+           var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
            var nl_logit = logits[nl_token];

            // Convert logits into token candidates


+ 1 - 1  LLama/LLamaInstructExecutor.cs

@@ -163,7 +163,7 @@ namespace LLama
                }
            }

-           if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+           if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
            {
                args.WaitForInput = true;
            }


+ 3 - 3  LLama/LLamaInteractExecutor.cs

@@ -30,7 +30,7 @@ namespace LLama
        public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
            : base(context, logger)
        {
-           _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle);
+           _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
        }

        /// <inheritdoc />
@@ -141,7 +141,7 @@ namespace LLama
                return (true, Array.Empty<string>());
            }

-           if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+           if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
            {
                return (true, new[] { " [end of text]\n" });
            }
@@ -202,7 +202,7 @@ namespace LLama

            _last_n_tokens.Enqueue(id);

-           if (id == NativeApi.llama_token_eos(Context.NativeHandle))
+           if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
            {
                id = _llama_token_newline;
                if (args.Antiprompts is not null && args.Antiprompts.Count > 0)


+ 1 - 1  LLama/Native/LLamaBatchSafeHandle.cs

@@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle
    /// <summary>
    /// Get the native llama_batch struct
    /// </summary>
-   public LLamaNativeBatch NativeBatch { get; private set; }
+   public LLamaNativeBatch NativeBatch;

    /// <summary>
    /// the token ids of the input (used when embd is NULL)
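The switch from an auto-property to a public field matters because the new `llama_batch_add`/`llama_batch_clear` helpers mutate the struct in place through a `ref`, and a `{ get; private set; }` property only ever hands out a copy. A minimal sketch of the pattern (the `ClearBatch` name and `handle` parameter are illustrative, not part of the commit):

    // Sketch: with NativeBatch as a field, callers can take a ref to the struct
    // and mutate it directly, as llama_batch_clear in BatchedDecoding.cs does.
    static void ClearBatch(LLamaBatchSafeHandle handle)
    {
        ref LLamaNativeBatch native = ref handle.NativeBatch; // taking a ref would not compile against a property
        native.n_tokens = 0;                                  // reset the batch for the next decode step
    }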


+ 15 - 3  LLama/Native/LLamaNativeBatch.cs

@@ -11,12 +11,12 @@ using llama_token = Int32;
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
/// </summary>
[StructLayout(LayoutKind.Sequential)]
-public readonly unsafe struct LLamaNativeBatch
+public unsafe struct LLamaNativeBatch
{
    /// <summary>
    /// The number of items pointed at by pos, seq_id and logits.
    /// </summary>
-   public readonly int n_tokens;
+   public int n_tokens;

    /// <summary>
    /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
@@ -33,13 +33,25 @@ public readonly unsafe struct LLamaNativeBatch
    /// </summary>
    public readonly LLamaPos* pos;

+   /// <summary>
+   /// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ???
+   /// </summary>
+   public readonly int* n_seq_id;
+
    /// <summary>
    /// the sequence to which the respective token belongs
    /// </summary>
-   public readonly LLamaSeqId* seq_id;
+   public readonly LLamaSeqId** seq_id;

    /// <summary>
    /// if zero, the logits for the respective token will not be output
    /// </summary>
    public readonly byte* logits;
+
+   // Note from llama.cpp:
+   // > helpers for smooth API transition - can be deprecated in the future
+   // > for future-proof code, use the above fields instead and ignore everything below
+   private LLamaPos _all_pos_0;
+   private LLamaPos _all_pos_1;
+   private LLamaSeqId _all_seq_id;
}
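The key layout change is that `seq_id` is now `LLamaSeqId**`: each token carries its own array of sequence ids, with `n_seq_id[i]` giving the count for token `i`. This is what the `llama_batch_add` helper in BatchedDecoding.cs relies on. A minimal sketch of filling one slot, assuming a `batch` of type `LLamaBatchSafeHandle` and illustrative `tokenId`, `position` and `sequence` values:

    unsafe
    {
        ref var b = ref batch.NativeBatch;

        b.token[b.n_tokens] = tokenId;          // the token to decode
        b.pos[b.n_tokens] = position;           // its position within the sequence
        b.n_seq_id[b.n_tokens] = 1;             // this token belongs to a single sequence
        b.seq_id[b.n_tokens][0] = sequence;     // seq_id is now a per-token array (LLamaSeqId**)
        b.logits[b.n_tokens] = 1;               // non-zero: request logits for this token
        b.n_tokens++;
    }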

+ 4 - 4  LLama/Native/NativeApi.cs

@@ -350,7 +350,7 @@ namespace LLama.Native
        /// <param name="ctx"></param>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);
+       public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);

        /// <summary>
        /// Get the embeddings for the input
@@ -366,21 +366,21 @@ namespace LLama.Native
        /// </summary>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
+       public static extern llama_token llama_token_bos(SafeLlamaModelHandle model);

        /// <summary>
        /// Get the "End of sentence" token
        /// </summary>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
+       public static extern llama_token llama_token_eos(SafeLlamaModelHandle model);

        /// <summary>
        /// Get the "new line" token
        /// </summary>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
+       public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);

        /// <summary>
        /// Print out timing information for this context
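For downstream code, the practical impact is the signature change: the token-id queries (`llama_token_bos`/`llama_token_eos`/`llama_token_nl`) now take the model handle instead of a context handle, and `llama_get_logits_ith` now takes the batch index of the token whose logits you want. A sketch of the call-site migration, assuming a `context` variable of type `LLamaContext` (as in the executor diffs above) and `i` standing in for a batch index:

    // Before: NativeApi.llama_token_eos(context.NativeHandle)
    var eos = NativeApi.llama_token_eos(context.NativeHandle.ModelHandle);
    var nl = NativeApi.llama_token_nl(context.NativeHandle.ModelHandle);

    unsafe
    {
        // Read the logits of the token at batch index `i` from the last llama_decode call.
        float* logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i);
    }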

