- Added a "spinner" to the `StatelessModeExecute` demo, which spins while waiting for the next token (demonstrating that the caller is not blocked while inference runs).

tags/v0.6.0
| @@ -35,11 +35,49 @@ namespace LLama.Examples.NewVersion | |||||
| Console.ForegroundColor = ConsoleColor.White; | Console.ForegroundColor = ConsoleColor.White; | ||||
| Console.Write("Answer: "); | Console.Write("Answer: "); | ||||
| prompt = $"Question: {prompt?.Trim()} Answer: "; | prompt = $"Question: {prompt?.Trim()} Answer: "; | ||||
| await foreach (var text in ex.InferAsync(prompt, inferenceParams)) | |||||
| await foreach (var text in Spinner(ex.InferAsync(prompt, inferenceParams))) | |||||
| { | { | ||||
| Console.Write(text); | Console.Write(text); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Show a spinner on the console while waiting for the next result from <paramref name="source"/>.
/// </summary>
/// <remarks>
/// Every item is passed through unchanged and in its original order. While the next item is
/// pending, a rotating character is drawn at the current cursor position; it is erased again
/// before the item is yielded, so the caller can write at the same position.
/// When console output is redirected no spinner is drawn and the items are simply forwarded.
/// </remarks>
/// <param name="source">The asynchronous sequence to forward</param>
/// <returns>The items of <paramref name="source"/>, unmodified</returns>
private static async IAsyncEnumerable<string> Spinner(IAsyncEnumerable<string> source)
{
    // Without an interactive console there is nowhere sensible to draw the spinner
    // (cursor positioning would only pollute the redirected output), so just forward the items.
    if (Console.IsOutputRedirected)
    {
        await foreach (var item in source)
            yield return item;
        yield break;
    }

    // Time each spinner frame stays on screen
    const int frameDelayMs = 75;
    var characters = new[] { '|', '/', '-', '\\' };

    // Dispose the enumerator when this iterator finishes, fails, or is abandoned early by the
    // consumer, so that cleanup code in the source iterator is not skipped.
    await using var enumerator = source.GetAsyncEnumerator();

    while (true)
    {
        // Convert to a Task: it is polled repeatedly below and then awaited, which a raw ValueTask is not meant for
        var next = enumerator.MoveNextAsync().AsTask();
        var (left, top) = Console.GetCursorPosition();

        // Keep showing the next spinner character while waiting for "MoveNextAsync" to finish
        var frame = 0;
        while (!next.IsCompleted)
        {
            frame = (frame + 1) % characters.Length;
            Console.SetCursorPosition(left, top);
            Console.Write(characters[frame]);

            // Wake up as soon as the item is ready instead of always sleeping out the full frame delay.
            // WhenAny never throws here; a failure of "next" is surfaced by the await further down.
            await Task.WhenAny(next, Task.Delay(frameDelayMs));
        }

        // Clear the spinner character
        Console.SetCursorPosition(left, top);
        Console.Write(" ");
        Console.SetCursorPosition(left, top);

        // Rethrows any exception from the source; "false" means the source is exhausted
        if (!await next)
            break;

        yield return enumerator.Current;
    }
}
| } | } | ||||
| } | } | ||||
| @@ -68,6 +68,13 @@ namespace LLama.Extensions | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Check if the given set of tokens ends with any of the given strings, taking the
/// model handle and the text encoding from the supplied context.
/// </summary>
/// <param name="tokens">Tokens to check</param>
/// <param name="queries">Strings to search for</param>
/// <param name="context">Context providing the model handle and encoding that are forwarded to the underlying overload</param>
/// <returns>The result of the overload that accepts a model handle and an encoding</returns>
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, LLamaContext context)
    where TTokens : IReadOnlyList<int>
    where TQueries : IReadOnlyList<string>
{
    // Pull the two pieces of state out of the context, then defer to the general overload
    var model = context.NativeHandle.ModelHandle;
    var encoding = context.Encoding;
    return TokensEndsWithAnyString(tokens, queries, model, encoding);
}
| /// <summary> | /// <summary> | ||||
| /// Check if the given set of tokens ends with any of the given strings | /// Check if the given set of tokens ends with any of the given strings | ||||
| /// </summary> | /// </summary> | ||||
| @@ -406,7 +406,7 @@ namespace LLama | |||||
| /// <param name="pastTokensCount"></param> | /// <param name="pastTokensCount"></param> | ||||
| /// <returns>The updated `pastTokensCount`.</returns> | /// <returns>The updated `pastTokensCount`.</returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public int Eval(llama_token[] tokens, llama_token pastTokensCount) | |||||
| public int Eval(llama_token[] tokens, int pastTokensCount) | |||||
| { | { | ||||
| return Eval(tokens.AsSpan(), pastTokensCount); | return Eval(tokens.AsSpan(), pastTokensCount); | ||||
| } | } | ||||
| @@ -418,7 +418,7 @@ namespace LLama | |||||
| /// <param name="pastTokensCount"></param> | /// <param name="pastTokensCount"></param> | ||||
| /// <returns>The updated `pastTokensCount`.</returns> | /// <returns>The updated `pastTokensCount`.</returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public int Eval(List<llama_token> tokens, llama_token pastTokensCount) | |||||
| public int Eval(List<llama_token> tokens, int pastTokensCount) | |||||
| { | { | ||||
| #if NET5_0_OR_GREATER | #if NET5_0_OR_GREATER | ||||
| var span = CollectionsMarshal.AsSpan(tokens); | var span = CollectionsMarshal.AsSpan(tokens); | ||||
| @@ -448,7 +448,7 @@ namespace LLama | |||||
| /// <param name="pastTokensCount"></param> | /// <param name="pastTokensCount"></param> | ||||
| /// <returns>The updated `pastTokensCount`.</returns> | /// <returns>The updated `pastTokensCount`.</returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount) | |||||
| public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount) | |||||
| { | { | ||||
| return Eval(tokens.Span, pastTokensCount); | return Eval(tokens.Span, pastTokensCount); | ||||
| } | } | ||||
| @@ -460,7 +460,7 @@ namespace LLama | |||||
| /// <param name="pastTokensCount"></param> | /// <param name="pastTokensCount"></param> | ||||
| /// <returns>The updated `pastTokensCount`.</returns> | /// <returns>The updated `pastTokensCount`.</returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount) | |||||
| public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount) | |||||
| { | { | ||||
| var total = tokens.Length; | var total = tokens.Length; | ||||
| for(var i = 0; i < total; i += Params.BatchSize) | for(var i = 0; i < total; i += Params.BatchSize) | ||||
| @@ -5,6 +5,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Threading; | using System.Threading; | ||||
| using System.Threading.Tasks; | |||||
| using LLama.Extensions; | using LLama.Extensions; | ||||
| namespace LLama | namespace LLama | ||||
| @@ -73,7 +74,6 @@ namespace LLama | |||||
| cancellationToken.ThrowIfCancellationRequested(); | cancellationToken.ThrowIfCancellationRequested(); | ||||
| var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>(); | var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>(); | ||||
| var n_past = 1; | |||||
| inferenceParams ??= new InferenceParams(); | inferenceParams ??= new InferenceParams(); | ||||
| var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount); | var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount); | ||||
| @@ -81,12 +81,12 @@ namespace LLama | |||||
| lastTokens.Add(0); | lastTokens.Add(0); | ||||
| var tokens = Context.Tokenize(text).ToList(); | var tokens = Context.Tokenize(text).ToList(); | ||||
| var n_prompt_tokens = tokens.Count; | |||||
| Context.Eval(tokens, n_past); | |||||
| await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken) | |||||
| .ConfigureAwait(false); | |||||
| lastTokens.AddRange(tokens); | lastTokens.AddRange(tokens); | ||||
| n_past += n_prompt_tokens; | |||||
| var n_past = 1 + tokens.Count; | |||||
| var mu = (float?)null; | var mu = (float?)null; | ||||
| var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | ||||
| @@ -111,7 +111,8 @@ namespace LLama | |||||
| tokens.Clear(); | tokens.Clear(); | ||||
| tokens.Add(id); | tokens.Add(id); | ||||
| if (EndsWithAntiprompt(lastTokens, antiprompts)) | |||||
| // Check if any of the antiprompts have been generated | |||||
| if (tokens.TokensEndsWithAnyString(antiprompts, Context)) | |||||
| break; | break; | ||||
| // when run out of context | // when run out of context | ||||
| @@ -126,19 +127,10 @@ namespace LLama | |||||
| tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); | tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); | ||||
| } | } | ||||
| n_past = Context.Eval(tokens, n_past); | |||||
| // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently) | |||||
| n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken) | |||||
| .ConfigureAwait(false); | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Check if the given tokens list ends with any of the antiprompts | |||||
| /// </summary> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="antiprompts"></param> | |||||
| /// <returns></returns> | |||||
| private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts) | |||||
| { | |||||
| return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||