
- Converted LLamaStatelessExecutor to run its `Eval` calls inside an awaited task. This unblocks async callers while the model is being evaluated (a minimal sketch of the pattern follows these notes).

- Added a "spinner" to the `StatelessModeExecute` demo, which spins while waiting for the next token (demonstrating that it's not blocked).
Martin Evans committed 2 years ago (tags/v0.6.0)
commit 08f1615e60
4 changed files with 59 additions and 22 deletions
  1. LLama.Examples/NewVersion/StatelessModeExecute.cs (+39 -1)
  2. LLama/Extensions/IReadOnlyListExtensions.cs (+7 -0)
  3. LLama/LLamaContext.cs (+4 -4)
  4. LLama/LLamaStatelessExecutor.cs (+9 -17)

LLama.Examples/NewVersion/StatelessModeExecute.cs (+39 -1)

@@ -35,11 +35,49 @@ namespace LLama.Examples.NewVersion
                 Console.ForegroundColor = ConsoleColor.White;
                 Console.Write("Answer: ");
                 prompt = $"Question: {prompt?.Trim()} Answer: ";
-                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
+                await foreach (var text in Spinner(ex.InferAsync(prompt, inferenceParams)))
                 {
                     Console.Write(text);
                 }
             }
         }
+
+        /// <summary>
+        /// Show a spinner while waiting for the next result
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        private static async IAsyncEnumerable<string> Spinner(IAsyncEnumerable<string> source)
+        {
+            var enumerator = source.GetAsyncEnumerator();
+
+            var characters = new[] { '|', '/', '-', '\\' };
+
+            while (true)
+            {
+                var next = enumerator.MoveNextAsync();
+
+                var (Left, Top) = Console.GetCursorPosition();
+
+                // Keep showing the next spinner character while waiting for "MoveNextAsync" to finish
+                var count = 0;
+                while (!next.IsCompleted)
+                {
+                    count = (count + 1) % characters.Length;
+                    Console.SetCursorPosition(Left, Top);
+                    Console.Write(characters[count]);
+                    await Task.Delay(75);
+                }
+
+                // Clear the spinner character
+                Console.SetCursorPosition(Left, Top);
+                Console.Write(" ");
+                Console.SetCursorPosition(Left, Top);
+
+                if (!next.Result)
+                    break;
+                yield return enumerator.Current;
+            }
+        }
     }
 }
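Two details are worth noting if this demo pattern is reused elsewhere (the variation below is a sketch, not part of the commit): the enumerator returned by `GetAsyncEnumerator` should normally be disposed, and a `ValueTask<bool>` is only safely consumed once, so converting it with `AsTask()` makes the poll-then-await loop unambiguous:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

static class SpinnerSketch
{
    private static async IAsyncEnumerable<string> Spinner(IAsyncEnumerable<string> source)
    {
        var characters = new[] { '|', '/', '-', '\\' };

        // "await using" guarantees DisposeAsync runs even if the caller abandons the loop
        await using var enumerator = source.GetAsyncEnumerator();

        while (true)
        {
            // AsTask() yields a Task that can be polled repeatedly and awaited later
            var next = enumerator.MoveNextAsync().AsTask();
            var (left, top) = Console.GetCursorPosition();

            var count = 0;
            while (!next.IsCompleted)
            {
                count = (count + 1) % characters.Length;
                Console.SetCursorPosition(left, top);
                Console.Write(characters[count]);
                await Task.Delay(75);
            }

            // Erase the spinner character before printing the real token
            Console.SetCursorPosition(left, top);
            Console.Write(' ');
            Console.SetCursorPosition(left, top);

            if (!await next)
                break;

            yield return enumerator.Current;
        }
    }
}
```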

LLama/Extensions/IReadOnlyListExtensions.cs (+7 -0)

@@ -68,6 +68,13 @@ namespace LLama.Extensions
             }
         }
 
+        internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, LLamaContext context)
+            where TTokens : IReadOnlyList<int>
+            where TQueries : IReadOnlyList<string>
+        {
+            return TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding);
+        }
+
         /// <summary>
         /// Check if the given set of tokens ends with any of the given strings
         /// </summary>
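The new overload simply forwards to the existing worker, pulling the model handle and encoding off the `LLamaContext`. Conceptually, the underlying check detokenizes a tail of the token list and does a string suffix comparison; a rough standalone sketch of that idea (the `detokenize` delegate is a placeholder, not the LLamaSharp API):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

internal static class AntipromptCheckSketch
{
    // Rough sketch: "does the detokenized tail of `tokens` end with any query?"
    internal static bool EndsWithAnyString(
        IReadOnlyList<int> tokens,
        IReadOnlyList<string>? queries,
        Func<IEnumerable<int>, string> detokenize)
    {
        if (queries == null || queries.Count == 0 || tokens.Count == 0)
            return false;

        // Assuming each token decodes to at least one character, a query can
        // never span more tokens than it has characters, so a short tail suffices.
        var maxQueryLength = queries.Max(q => q.Length);
        var tail = tokens.Skip(Math.Max(0, tokens.Count - maxQueryLength));
        var text = detokenize(tail);

        return queries.Any(q => text.EndsWith(q, StringComparison.Ordinal));
    }
}
```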


LLama/LLamaContext.cs (+4 -4)

@@ -406,7 +406,7 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public int Eval(llama_token[] tokens, llama_token pastTokensCount)
+        public int Eval(llama_token[] tokens, int pastTokensCount)
         {
             return Eval(tokens.AsSpan(), pastTokensCount);
         }
@@ -418,7 +418,7 @@
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
+        public int Eval(List<llama_token> tokens, int pastTokensCount)
         {
 #if NET5_0_OR_GREATER
             var span = CollectionsMarshal.AsSpan(tokens);
@@ -448,7 +448,7 @@
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount)
+        public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
         {
             return Eval(tokens.Span, pastTokensCount);
         }
@@ -460,7 +460,7 @@
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
+        public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
         {
             var total = tokens.Length;
             for(var i = 0; i < total; i += Params.BatchSize)
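The last hunk shows the shape of the batching loop inside `Eval`: tokens are fed to the native model `BatchSize` at a time, with `pastTokensCount` advancing after each chunk. A simplified standalone version of that loop (the `nativeEval` delegate stands in for the real native binding, which this page does not show):

```csharp
using System;

static class BatchedEvalSketch
{
    // Illustrative batching loop: process `tokens` in chunks of `batchSize`,
    // advancing the past-token counter after each native evaluation call.
    // nativeEval(chunk, count, n_past) returns false on failure.
    public static int EvalBatched(
        ReadOnlySpan<int> tokens, int pastTokensCount, int batchSize,
        Func<int[], int, int, bool> nativeEval)
    {
        var total = tokens.Length;
        for (var i = 0; i < total; i += batchSize)
        {
            var n_eval = Math.Min(total - i, batchSize);
            var chunk = tokens.Slice(i, n_eval).ToArray();

            if (!nativeEval(chunk, n_eval, pastTokensCount))
                throw new InvalidOperationException("Native eval failed");

            pastTokensCount += n_eval;
        }
        return pastTokensCount;
    }
}
```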


LLama/LLamaStatelessExecutor.cs (+9 -17)

@@ -5,6 +5,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
+using System.Threading.Tasks;
 using LLama.Extensions;
 
 namespace LLama
@@ -73,7 +74,6 @@ namespace LLama
             cancellationToken.ThrowIfCancellationRequested();
 
             var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
-            var n_past = 1;
             inferenceParams ??= new InferenceParams();
 
             var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
@@ -81,12 +81,12 @@
                 lastTokens.Add(0);
 
             var tokens = Context.Tokenize(text).ToList();
-            var n_prompt_tokens = tokens.Count;
 
-            Context.Eval(tokens, n_past);
+            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
+                      .ConfigureAwait(false);
 
             lastTokens.AddRange(tokens);
-            n_past += n_prompt_tokens;
+            var n_past = 1 + tokens.Count;
 
             var mu = (float?)null;
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
@@ -111,7 +111,8 @@
                 tokens.Clear();
                 tokens.Add(id);
 
-                if (EndsWithAntiprompt(lastTokens, antiprompts))
+                // Check if any of the antiprompts have been generated
+                if (tokens.TokensEndsWithAnyString(antiprompts, Context))
                     break;
 
                 // when run out of context
@@ -126,19 +127,10 @@
                     tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
                 }
 
-                n_past = Context.Eval(tokens, n_past);
+                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
+                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
+                                   .ConfigureAwait(false);
             }
         }
-
-        /// <summary>
-        /// Check if the given tokens list ends with any of the antiprompts
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <param name="antiprompts"></param>
-        /// <returns></returns>
-        private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts)
-        {
-            return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
-        }
     }
 }
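For reference, the `n_left / 2` bookkeeping visible in the last hunk is the usual llama.cpp context-rollover trick: when the window fills, rewind `n_past` to the kept prefix and re-submit only the most recent half of the evicted span so the model regains some short-term memory. A toy sketch of just that step (names simplified; `tokensToKeep` stands in for the executor's keep-count setting):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static class ContextRolloverSketch
{
    // Returns the tokens to re-evaluate after the window fills, rewinding
    // n_past so future evals resume right after the kept prefix.
    public static List<int> RollOver(List<int> lastTokens, ref int n_past, int tokensToKeep)
    {
        var n_left = n_past - tokensToKeep;

        // Rewind: evaluation resumes just past the kept prefix
        n_past = Math.Max(1, tokensToKeep);

        // Re-submit the most recent half of the evicted span
        return lastTokens
            .Skip(lastTokens.Count - n_left / 2)
            .Take(n_left / 2)
            .ToList();
    }
}
```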
