From 629430a087be3c69df2e98d2865272882edcd0fd Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Thu, 16 Nov 2023 14:09:14 -0600 Subject: [PATCH 01/22] Correctly format followup messages in turn-based (chat) inference --- LLama/ChatSession.cs | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 7ee99590..748d2ef3 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -159,15 +159,15 @@ namespace LLama InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); prompt = state.IsPromptRun ? HistoryTransform.HistoryToText(History) - : prompt; + : HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(AuthorRole.User, prompt)); } StringBuilder sb = new(); - await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) + await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) { - yield return result; - sb.Append(result); + yield return textToken; + sb.Append(textToken); } string assistantMessage = sb.ToString(); @@ -180,7 +180,7 @@ namespace LLama { foreach (var stopToken in inferenceParams.AntiPrompts) { - assistantMessage = assistantMessage.Replace(stopToken, ""); + assistantMessage = assistantMessage.Replace(stopToken, "").Trim(); } } @@ -209,27 +209,37 @@ namespace LLama { InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); - prompt = state.IsPromptRun - ? HistoryTransform.HistoryToText(History) - : history.Messages.Last().Content; + if (state.IsPromptRun) + { + prompt = HistoryTransform.HistoryToText(History); + } + else + { + ChatHistory.Message lastMessage = history.Messages.Last(); + prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content)); + } } else { - prompt = history.Messages.Last().Content; + ChatHistory.Message lastMessage = history.Messages.Last(); + prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content)); } - await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) + await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) { - yield return result; + yield return textToken; } } private async IAsyncEnumerable ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Console.ForegroundColor = ConsoleColor.Gray; + Console.WriteLine(prompt); + var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); - await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) + await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) { - yield return item; + yield return textToken; } } } From 75932afc62993f2ef4132ec36975175749ba5ebd Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Thu, 16 Nov 2023 14:25:41 -0600 Subject: [PATCH 02/22] Remove debug output --- LLama/ChatSession.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 748d2ef3..d1504a08 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -233,9 +233,6 @@ namespace LLama private async IAsyncEnumerable ChatAsyncInternal(string prompt, IInferenceParams? 
inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - Console.ForegroundColor = ConsoleColor.Gray; - Console.WriteLine(prompt); - var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) { From 33358124db7f692b6f73070caffa1da03e368934 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Mon, 4 Dec 2023 01:31:42 +0000 Subject: [PATCH 03/22] Initial pass at a new sampling pipeline --- LLama/Native/LLamaTokenDataArray.cs | 10 +- LLama/Sampling/ISamplingPipeline.cs | 99 +++++++++++++++++ LLama/Sampling/Logits/ILogitProcessor.cs | 34 ++++++ LLama/Sampling/Logits/LogitBias.cs | 39 +++++++ LLama/Sampling/Logits/SaveLoad.cs | 100 ++++++++++++++++++ LLama/Sampling/Selection/GreedySelection.cs | 27 +++++ LLama/Sampling/Selection/ITokenSelector.cs | 25 +++++ .../Sampling/Selection/Mirostat2Selection.cs | 65 ++++++++++++ LLama/Sampling/Selection/MirostatSelection.cs | 76 +++++++++++++ LLama/Sampling/Selection/StandardSelection.cs | 27 +++++ LLama/Sampling/Tokens/GrammarSampling.cs | 59 +++++++++++ LLama/Sampling/Tokens/ITokenDataProcessor.cs | 34 ++++++ .../Sampling/Tokens/LocallyTypicalSampling.cs | 42 ++++++++ LLama/Sampling/Tokens/MinPSampling.cs | 42 ++++++++ LLama/Sampling/Tokens/RepetitionPenalty.cs | 77 ++++++++++++++ LLama/Sampling/Tokens/TailFreeSampling.cs | 42 ++++++++ LLama/Sampling/Tokens/TemperatureSampling.cs | 38 +++++++ LLama/Sampling/Tokens/TopKSampling.cs | 38 +++++++ LLama/Sampling/Tokens/TopPSampling.cs | 42 ++++++++ 19 files changed, 912 insertions(+), 4 deletions(-) create mode 100644 LLama/Sampling/ISamplingPipeline.cs create mode 100644 LLama/Sampling/Logits/ILogitProcessor.cs create mode 100644 LLama/Sampling/Logits/LogitBias.cs create mode 100644 LLama/Sampling/Logits/SaveLoad.cs create mode 100644 LLama/Sampling/Selection/GreedySelection.cs create mode 100644 LLama/Sampling/Selection/ITokenSelector.cs create mode 100644 LLama/Sampling/Selection/Mirostat2Selection.cs create mode 100644 LLama/Sampling/Selection/MirostatSelection.cs create mode 100644 LLama/Sampling/Selection/StandardSelection.cs create mode 100644 LLama/Sampling/Tokens/GrammarSampling.cs create mode 100644 LLama/Sampling/Tokens/ITokenDataProcessor.cs create mode 100644 LLama/Sampling/Tokens/LocallyTypicalSampling.cs create mode 100644 LLama/Sampling/Tokens/MinPSampling.cs create mode 100644 LLama/Sampling/Tokens/RepetitionPenalty.cs create mode 100644 LLama/Sampling/Tokens/TailFreeSampling.cs create mode 100644 LLama/Sampling/Tokens/TemperatureSampling.cs create mode 100644 LLama/Sampling/Tokens/TopKSampling.cs create mode 100644 LLama/Sampling/Tokens/TopPSampling.cs diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 4bc154f4..897cf8b8 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -145,15 +145,17 @@ namespace LLama.Native /// /// /// - public void RepetitionPenalty(SafeLLamaContextHandle context, Memory last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) + public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) { unsafe { using (LLamaTokenDataArrayNative.Create(this, out var st)) - using (var last_tokens_handle = last_tokens.Pin()) { - NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, 
(ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); - sorted = st.sorted; + fixed (int* last_tokens_handle = last_tokens) + { + NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); + sorted = st.sorted; + } } } } diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs new file mode 100644 index 00000000..489f2c5a --- /dev/null +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using LLama.Native; +using LLama.Sampling.Logits; +using LLama.Sampling.Selection; +using LLama.Sampling.Tokens; + +namespace LLama.Sampling; + +/// +/// Convert a span of logits into a single sampled token +/// +public interface ISamplingPipeline + : IDisposable +{ + /// + /// Sample a single token from the given logits + /// + /// + /// + /// + /// + int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + + /// + /// Reset all internal state of the sampling pipeline + /// + void Reset(); +} + +/// +/// Simple implementation of `ISamplingPipeline`, applies processors in order every time +/// +public sealed class BasicSamplingPipeline + : ISamplingPipeline +{ + /// + /// Logit processors to apply in this pipeline + /// + public IList LogitProcessors { get; } = new List(); + + /// + /// Token data processors to apply in this pipeline + /// + public IList TokenDataProcessors { get; } = new List(); + + /// + /// The selector to choose the final token + /// + public ITokenSelector Selector { get; set; } = new StandardSelection(); + + /// + public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + // Modify raw logits + foreach (var logitProcessor in LogitProcessors) + logitProcessor.ProcessLogits(ctx, logits, lastTokens); + + // Convert logits into token candidates + var candidates_p = LLamaTokenDataArray.Create(logits); + + // Process token candidates + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.ProcessTokens(ctx, candidates_p, lastTokens); + + // Select a token + var token = Selector.Select(ctx, candidates_p, lastTokens); + + // Tell processors what was selected + foreach (var logitProcessor in LogitProcessors) + logitProcessor.AcceptToken(ctx, token); + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.AcceptToken(ctx, token); + + return token; + } + + /// + public void Reset() + { + foreach (var logitProcessor in LogitProcessors) + logitProcessor.Reset(); + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.Reset(); + + Selector.Reset(); + } + + /// + public void Dispose() + { + foreach (var logitProcessor in LogitProcessors) + logitProcessor.Dispose(); + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.Dispose(); + + Selector.Dispose(); + } +} \ No newline at end of file diff --git a/LLama/Sampling/Logits/ILogitProcessor.cs b/LLama/Sampling/Logits/ILogitProcessor.cs new file mode 100644 index 00000000..76968499 --- /dev/null +++ b/LLama/Sampling/Logits/ILogitProcessor.cs @@ -0,0 +1,34 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Logits; + +using llama_token = Int32; + +/// +/// Processes raw logits before sampling, applying penalties to certain tokens +/// +public interface ILogitProcessor + : IDisposable +{ + /// + /// Process raw logits, indexed by llama_token + /// + /// The context this is operating in + /// 
The token data array to process + /// The most recent tokens output + /// LLamaTokenDataArray, created from logits + void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + + /// + /// Inform this process when a token is accepted by the model + /// + /// + /// + void AcceptToken(SafeLLamaContextHandle ctx, int token); + + /// + /// Reset all internal sampling state + /// + void Reset(); +} \ No newline at end of file diff --git a/LLama/Sampling/Logits/LogitBias.cs b/LLama/Sampling/Logits/LogitBias.cs new file mode 100644 index 00000000..fc821508 --- /dev/null +++ b/LLama/Sampling/Logits/LogitBias.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Sampling.Logits; + +/// +/// Add a bias directly to logit values +/// +public sealed class LogitBias + : ILogitProcessor +{ + /// + /// Biases to apply, token -> bias + /// + public IDictionary Biases { get; } = new Dictionary(); + + /// + public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + foreach (var kvp in Biases) + logits[kvp.Key] += kvp.Value; + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Logits/SaveLoad.cs b/LLama/Sampling/Logits/SaveLoad.cs new file mode 100644 index 00000000..6f80aec4 --- /dev/null +++ b/LLama/Sampling/Logits/SaveLoad.cs @@ -0,0 +1,100 @@ +using System; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Sampling.Logits; + +/// +/// Save certain logit values +/// +public sealed class SaveLogitValues + : ILogitProcessor +{ + private readonly Dictionary _saved = new(); + + /// + /// Logits to save + /// + public ISet Logits { get; } = new HashSet(); + + /// + /// Saved logit values + /// + public IReadOnlyDictionary Values => _saved; + + /// + public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + _saved.Clear(); + foreach (var logit in Logits) + _saved[logit] = logits[logit]; + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + _saved.Clear(); + } + + /// + public void Dispose() + { + } + + /// + /// Get a logit processor that overwrite the logit values with the values saved here + /// + /// + public ILogitProcessor GetWriter() + { + return new LoadLogitValues(_saved); + } +} + +/// +/// Overwrite certain logit values +/// +public sealed class LoadLogitValues + : ILogitProcessor +{ + /// + /// Logits to overwrite, token -> logit + /// + public IDictionary Values { get; } + + /// + /// Create a new LoadLogitValues + /// + /// Source for values to overwrite + public LoadLogitValues(Dictionary? values = null) + { + Values = values ?? 
new Dictionary(); + } + + /// + public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + foreach (var logit in Values) + logits[logit.Key] = logit.Value; + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/GreedySelection.cs b/LLama/Sampling/Selection/GreedySelection.cs new file mode 100644 index 00000000..30b72456 --- /dev/null +++ b/LLama/Sampling/Selection/GreedySelection.cs @@ -0,0 +1,27 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select the most likely token +/// +public sealed class GreedySelection + : ITokenSelector +{ + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenGreedy(ctx); + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/ITokenSelector.cs b/LLama/Sampling/Selection/ITokenSelector.cs new file mode 100644 index 00000000..c8915a92 --- /dev/null +++ b/LLama/Sampling/Selection/ITokenSelector.cs @@ -0,0 +1,25 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select a single token from a set of possibilities +/// +public interface ITokenSelector + : IDisposable +{ + /// + /// Select a single token + /// + /// + /// + /// + /// + int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); + + /// + /// Reset the state + /// + void Reset(); +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/Mirostat2Selection.cs b/LLama/Sampling/Selection/Mirostat2Selection.cs new file mode 100644 index 00000000..cdc802c1 --- /dev/null +++ b/LLama/Sampling/Selection/Mirostat2Selection.cs @@ -0,0 +1,65 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select a token using Mirostat sampling. +/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. +/// +public sealed class Mirostat2Selection + : ITokenSelector +{ + private float _mu; + + /// + /// Current value of Mu, updated based on the difference between target surprise and actual surprise + /// + public float Mu + { + get => _mu; + set => _mu = value; + } + + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// + public float Tau { get; set; } + + /// + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. + /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// + public float Eta { get; set; } + + /// + /// Create a new Mirostat 2.0 sampler + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. 
+ /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + public Mirostat2Selection(float tau, float eta) + { + Tau = tau; + Eta = eta; + } + + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu); + } + + /// + public void Reset() + { + _mu = 2 * Tau; + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/MirostatSelection.cs b/LLama/Sampling/Selection/MirostatSelection.cs new file mode 100644 index 00000000..5ec34a13 --- /dev/null +++ b/LLama/Sampling/Selection/MirostatSelection.cs @@ -0,0 +1,76 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select a token using Mirostat sampling. +/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. +/// +public sealed class MirostatSelection + : ITokenSelector +{ + private float _mu; + + /// + /// Current value of Mu, updated based on the difference between target surprise and actual surprise + /// + public float Mu + { + get => _mu; + set => _mu = value; + } + + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// + public float Tau { get; set; } + + /// + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. + /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// + public float Eta { get; set; } + + /// + /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn + /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects + /// the performance of the algorithm. + /// + public int M { get; set; } + + /// + /// Create a new Mirostat 2.0 sampler + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. + /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn + /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects + /// the performance of the algorithm. 
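// Illustrative sketch (not part of this patch): the pieces introduced in this
// commit are meant to be composed. Assuming the types above, a pipeline might be
// wired up as below; the bias token id (42), the bias value and the tau/eta
// numbers are made-up placeholders:
//
//     var pipeline = new BasicSamplingPipeline
//     {
//         Selector = new Mirostat2Selection(tau: 5.0f, eta: 0.1f)
//     };
//     pipeline.LogitProcessors.Add(new LogitBias { Biases = { [42] = -10f } });
//
//     // per decoding step, where `ctx` is a SafeLLamaContextHandle, `logits` the
//     // raw Span<float> of logits and `lastTokens` the recent output tokens:
//     int token = pipeline.Sample(ctx, logits, lastTokens);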
+ public MirostatSelection(float tau, float eta, int m = 100) + { + Tau = tau; + Eta = eta; + M = m; + } + + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu); + } + + /// + public void Reset() + { + _mu = 2 * Tau; + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/StandardSelection.cs b/LLama/Sampling/Selection/StandardSelection.cs new file mode 100644 index 00000000..3e3bd086 --- /dev/null +++ b/LLama/Sampling/Selection/StandardSelection.cs @@ -0,0 +1,27 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select from all possible tokens according to their probability +/// +public sealed class StandardSelection + : ITokenSelector +{ + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleToken(ctx); + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/GrammarSampling.cs b/LLama/Sampling/Tokens/GrammarSampling.cs new file mode 100644 index 00000000..b823a7f9 --- /dev/null +++ b/LLama/Sampling/Tokens/GrammarSampling.cs @@ -0,0 +1,59 @@ +using System; +using LLama.Grammars; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Apply a grammar to prevent sampling tokens which do not match the grammar +/// +public sealed class GrammarSampling + : ITokenDataProcessor +{ + private SafeLLamaGrammarHandle? _handle; + + /// + /// Grammar to use for sampling + /// + public Grammar? Grammar { get; set; } + + /// + /// Create a new + /// + /// + public GrammarSampling(Grammar grammar) + { + Grammar = grammar; + } + + /// + public void Reset() + { + _handle?.Dispose(); + _handle = null; + } + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + // Create a new grammar instance if necessary + _handle ??= Grammar?.CreateInstance(); + + // Apply it + if (_handle != null) + tokens.ApplyGrammar(ctx, _handle); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + _handle?.AcceptToken(ctx, token); + } + + /// + public void Dispose() + { + _handle?.Dispose(); + _handle = null; + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/ITokenDataProcessor.cs b/LLama/Sampling/Tokens/ITokenDataProcessor.cs new file mode 100644 index 00000000..e6679cc2 --- /dev/null +++ b/LLama/Sampling/Tokens/ITokenDataProcessor.cs @@ -0,0 +1,34 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +using llama_token = Int32; + +/// +/// Processes token logits before sampling, applying penalties to certain tokens +/// +public interface ITokenDataProcessor + : IDisposable +{ + /// + /// Process token logits in a LLamaTokenDataArray + /// + /// The context this is operating in + /// The token data array to process + /// The most recent tokens output + /// LLamaTokenDataArray, created from logits + void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens); + + /// + /// Inform this process when a token is accepted by the model + /// + /// + /// + void AcceptToken(SafeLLamaContextHandle ctx, int token); + + /// + /// Reset all internal sampling state + /// + void Reset(); +} \ No newline at end of file diff --git 
a/LLama/Sampling/Tokens/LocallyTypicalSampling.cs b/LLama/Sampling/Tokens/LocallyTypicalSampling.cs new file mode 100644 index 00000000..3f602c9a --- /dev/null +++ b/LLama/Sampling/Tokens/LocallyTypicalSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. +/// +public sealed class LocallyTypicalSampling + : ITokenDataProcessor +{ + /// + /// P value for locally typical sampling + /// + public float P { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.LocallyTypical(ctx, P, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/MinPSampling.cs b/LLama/Sampling/Tokens/MinPSampling.cs new file mode 100644 index 00000000..c3adf026 --- /dev/null +++ b/LLama/Sampling/Tokens/MinPSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 +/// +public sealed class MinPSampling + : ITokenDataProcessor +{ + /// + /// All tokens with probability greater than this will be kept + /// + public float P { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.MinP(ctx, P, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/RepetitionPenalty.cs b/LLama/Sampling/Tokens/RepetitionPenalty.cs new file mode 100644 index 00000000..3cfdbcd4 --- /dev/null +++ b/LLama/Sampling/Tokens/RepetitionPenalty.cs @@ -0,0 +1,77 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. +/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. +/// +public sealed class RepetitionPenalty + : ITokenDataProcessor +{ + private float _alphaFreq; + private float _alphaPresence; + + /// + /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 + /// + public float RepeatPenalty { get; set; } = 1.1f; + + /// + /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+    /// </summary>
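// Illustrative note (not part of this patch): following the OpenAI definition
// linked above and the llama.cpp sampler this wraps, the combined penalties for
// a candidate token that already occurs `count` times in `lastTokens` behave
// roughly like:
//
//     logit -= count * AlphaFrequency + (count > 0 ? 1 : 0) * AlphaPresence;
//
// while RepeatPenalty divides positive logits (and multiplies negative ones) by
// its value, which is the negative-logit fix mentioned in the class summary.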
+ public float AlphaFrequency + { + get => _alphaFreq; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaFreq = value; + } + } + + /// + /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+    /// text so far, increasing the model's likelihood to talk about new topics.
+    /// </summary>
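// Illustrative sketch (not part of this patch): like the other token-data
// processors in this commit, this class is intended to be added to a pipeline's
// processor list, for example (1.18 is a made-up value):
//
//     pipeline.TokenDataProcessors.Add(new RepetitionPenalty { RepeatPenalty = 1.18f });
//
// where `pipeline` is the BasicSamplingPipeline introduced earlier in this commit.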
+ public float AlphaPresence + { + get => _alphaPresence; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaPresence = value; + } + } + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TailFreeSampling.cs b/LLama/Sampling/Tokens/TailFreeSampling.cs new file mode 100644 index 00000000..8e9fb2b5 --- /dev/null +++ b/LLama/Sampling/Tokens/TailFreeSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. +/// +public sealed class TailFreeSampling + : ITokenDataProcessor +{ + /// + /// Z value for tail free sampling + /// + public float Z { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.TailFree(ctx, Z, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TemperatureSampling.cs b/LLama/Sampling/Tokens/TemperatureSampling.cs new file mode 100644 index 00000000..0186f275 --- /dev/null +++ b/LLama/Sampling/Tokens/TemperatureSampling.cs @@ -0,0 +1,38 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Sample with temperature. +/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual +/// +public sealed class TemperatureSampling + : ITokenDataProcessor +{ + /// + /// Temperature value to apply + /// + public float Temperature { get; set; } = 0.5f; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.Temperature(ctx, Temperature); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopKSampling.cs b/LLama/Sampling/Tokens/TopKSampling.cs new file mode 100644 index 00000000..3f797c85 --- /dev/null +++ b/LLama/Sampling/Tokens/TopKSampling.cs @@ -0,0 +1,38 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Sample with TopK, removing all by the K most likely tokens. 
+/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 +/// +public sealed class TopKSampling + : ITokenDataProcessor +{ + /// + /// Number of tokens to keep + /// + public int Count { get; set; } + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.TopK(ctx, Count); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopPSampling.cs b/LLama/Sampling/Tokens/TopPSampling.cs new file mode 100644 index 00000000..577ce3bc --- /dev/null +++ b/LLama/Sampling/Tokens/TopPSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 +/// +public sealed class TopPSampling + : ITokenDataProcessor +{ + /// + /// P valies for TopP + /// + public float P { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.TopP(ctx, P, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file From f1eac82ecc4f49403ec06026aba3df7efdc8cb36 Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Sat, 25 Nov 2023 09:23:37 -0600 Subject: [PATCH 04/22] Update target frameworks with .NET 8 --- LLama.Examples/LLama.Examples.csproj | 2 +- LLama.KernelMemory/LLamaSharp.KernelMemory.csproj | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index b7369172..5053c038 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -2,7 +2,7 @@ Exe - net6.0 + net6.0;net7.0;net8.0 enable enable AnyCPU;x64 diff --git a/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj b/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj index 78d4712b..3867b7e1 100644 --- a/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj +++ b/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj @@ -1,7 +1,7 @@ - net6.0;net7.0 + net6.0;net7.0;net8.0 enable enable From cb480f04afdba27d0712b18c336ac80ad1a698e9 Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Sat, 25 Nov 2023 09:24:03 -0600 Subject: [PATCH 05/22] Prevent compilation errors due to duplicated assembly info --- LLama/LLamaSharp.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index 0e029c2d..5e7de5f4 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -28,6 +28,7 @@ AnyCPU;x64;Arm64 LLamaSharp Debug;Release;GPU + false From 67e6d633fd3564e0016b08fbdf8ed63bc038c36a Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Sat, 25 Nov 2023 09:28:39 -0600 Subject: [PATCH 06/22] Rebuild ChatSession class - Saves with serialized ChatHistory of session - Only allows use of ChatHistory.Message (instead of raw text) for easy post-processing with IHistoryTransform implementation - Provides History Management methods - Allows user to regenerate last assistant message --- LLama/ChatSession.cs | 645 
+++++++++++++++++++++++++----------- LLama/Common/ChatHistory.cs | 40 ++- 2 files changed, 483 insertions(+), 202 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index d1504a08..2985bd5f 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -1,243 +1,496 @@ -using LLama.Abstractions; -using LLama.Common; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.CompilerServices; -using System.Text; using System.Threading; using System.Threading.Tasks; +using LLama.Abstractions; +using LLama.Common; using static LLama.InteractiveExecutor; -namespace LLama +namespace LLama; + +/// +/// The main chat session class. +/// +public class ChatSession { + private const string _modelStateFilename = "ModelState.st"; + private const string _executorStateFilename = "ExecutorState.json"; + private const string _hsitoryFilename = "ChatHistory.json"; + /// - /// The main chat session class. - /// - public class ChatSession - { - private readonly ILLamaExecutor _executor; - private readonly ChatHistory _history; - - private const string _executorStateFilename = "ExecutorState.json"; - private const string _modelStateFilename = "ModelState.st"; - - /// - /// The executor for this session. - /// - public ILLamaExecutor Executor => _executor; - /// - /// The chat history for this session. - /// - public ChatHistory History => _history; - /// - /// The history transform used in this session. - /// - public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); - /// - /// The input transform pipeline used in this session. - /// - public List InputTransformPipeline { get; set; } = new(); - /// - /// The output transform used in this session. - /// - public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform(); - - /// - /// - /// - /// The executor for this session - public ChatSession(ILLamaExecutor executor) - { - _executor = executor; - _history = new ChatHistory(); - } - - /// - /// Use a custom history transform. - /// - /// - /// - public ChatSession WithHistoryTransform(IHistoryTransform transform) - { - HistoryTransform = transform; - return this; - } - - /// - /// Add a text transform to the input transform pipeline. - /// - /// - /// - public ChatSession AddInputTransform(ITextTransform transform) - { - InputTransformPipeline.Add(transform); - return this; - } - - /// - /// Use a custom output transform. - /// - /// - /// - public ChatSession WithOutputTransform(ITextStreamTransform transform) - { - OutputTransform = transform; - return this; - } - - /// - /// - /// - /// The directory name to save the session. If the directory does not exist, a new directory will be created. - public virtual void SaveSession(string path) - { - if (!Directory.Exists(path)) - { - Directory.CreateDirectory(path); - } - _executor.Context.SaveState(Path.Combine(path, _modelStateFilename)); - if (Executor is StatelessExecutor) - { + /// The executor for this session. + /// + public ILLamaExecutor Executor { get; private set; } - } - else if (Executor is StatefulExecutorBase statefulExecutor) - { - statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename)); - } - else - { - throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method."); - } + /// + /// The chat history for this session. 
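// Illustrative note (not part of this patch): the constructor further down
// requires an executor derived from StatefulExecutorBase, so a session is
// typically created as
//
//     var session = new ChatSession(new InteractiveExecutor(context));
//
// and passing a StatelessExecutor now throws an ArgumentException rather than
// being special-cased as in the old SaveSession/LoadSession code above.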
+ /// + public ChatHistory History { get; private set; } = new(); + + /// + /// The history transform used in this session. + /// + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); + + /// + /// The input transform pipeline used in this session. + /// + public List InputTransformPipeline { get; set; } = new(); + + /// + /// The output transform used in this session. + /// + public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform(); + + /// + /// Create a new chat session. + /// + /// The executor for this session + public ChatSession(ILLamaExecutor executor) + { + // Check if executor has StatefulExecutorBase as base class + if (executor is not StatefulExecutorBase) + { + throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); } - /// - /// - /// - /// The directory name to load the session. - public virtual void LoadSession(string path) + Executor = executor; + } + + /// + /// Create a new chat session with a custom history. + /// + /// + /// + public ChatSession(ILLamaExecutor executor, ChatHistory history) + : this(executor) + { + History = history; + } + + /// + /// Use a custom history transform. + /// + /// + /// + public ChatSession WithHistoryTransform(IHistoryTransform transform) + { + HistoryTransform = transform; + return this; + } + + /// + /// Add a text transform to the input transform pipeline. + /// + /// + /// + public ChatSession AddInputTransform(ITextTransform transform) + { + InputTransformPipeline.Add(transform); + return this; + } + + /// + /// Use a custom output transform. + /// + /// + /// + public ChatSession WithOutputTransform(ITextStreamTransform transform) + { + OutputTransform = transform; + return this; + } + + /// + /// Save a session from a directory. + /// + /// + /// + /// + public void SaveSession(string path) + { + if (string.IsNullOrWhiteSpace(path)) { - if (!Directory.Exists(path)) - { - throw new FileNotFoundException($"Directory {path} does not exist."); - } - _executor.Context.LoadState(Path.Combine(path, _modelStateFilename)); - if (Executor is StatelessExecutor) - { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } - } - else if (Executor is StatefulExecutorBase statefulExecutor) + if (Directory.Exists(path)) + { + Directory.Delete(path, recursive: true); + } + + Directory.CreateDirectory(path); + + string modelStateFilePath = Path.Combine(path, _modelStateFilename); + Executor.Context.SaveState(modelStateFilePath); + + string executorStateFilepath = Path.Combine(path, _executorStateFilename); + ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath); + + string historyFilepath = Path.Combine(path, _hsitoryFilename); + File.WriteAllText(historyFilepath, History.ToJson()); + } + + /// + /// Load a session from a directory. 
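// Illustrative sketch (not part of this patch): SaveSession above recreates the
// target directory and writes the context state (ModelState.st), the executor
// state (ExecutorState.json) and the serialized history (ChatHistory.json);
// LoadSession below restores all three, so a round trip looks like this, with
// "./chat-with-bob" as a placeholder path:
//
//     session.SaveSession("./chat-with-bob");
//     // ... later, on a fresh ChatSession over the same model ...
//     session.LoadSession("./chat-with-bob");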
+ /// + /// + /// + /// + public void LoadSession(string path) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } + + if (!Directory.Exists(path)) + { + throw new ArgumentException("Directory does not exist", nameof(path)); + } + + string modelStateFilePath = Path.Combine(path, _modelStateFilename); + Executor.Context.LoadState(modelStateFilePath); + + string executorStateFilepath = Path.Combine(path, _executorStateFilename); + ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath); + + string historyFilepath = Path.Combine(path, _hsitoryFilename); + string historyJson = File.ReadAllText(historyFilepath); + History = ChatHistory.FromJson(historyJson) + ?? throw new ArgumentException("History file is invalid", nameof(path)); + } + + /// + /// Add a message to the chat history. + /// + /// + /// + public ChatSession AddMessage(ChatHistory.Message message) + { + // If current message is a system message, only allow the history to be empty + if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0) + { + throw new ArgumentException("Cannot add a system message after another message", nameof(message)); + } + + // If current message is a user message, only allow the history to be empty, + // or the previous message to be a system message or assistant message. + if (message.AuthorRole == AuthorRole.User) + { + ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User) { - statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename)); + throw new ArgumentException("Cannot add a user message after another user message", nameof(message)); } - else + } + + // If the current message is an assistant message, + // the previous message must be a user message. + if (message.AuthorRole == AuthorRole.Assistant) + { + ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + if (lastMessage is null + || lastMessage.AuthorRole != AuthorRole.User) { - throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method."); + throw new ArgumentException("Assistant message must be preceeded with a user message", nameof(message)); } } - /// - /// Generates a response for a given user prompt and manages history state for the user. - /// This will always pass the whole history to the model. Don't pass a whole history - /// to this method as the user prompt will be appended to the history of the current session. - /// If more control is needed, use the other overload of this method that accepts a ChatHistory object. - /// - /// - /// - /// - /// Returns generated text of the assistant message. - public async IAsyncEnumerable ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + History.AddMessage(message.AuthorRole, message.Content); + return this; + } + + /// + /// Add a system message to the chat history. + /// + /// + /// + public ChatSession AddSystemMessage(string content) + => AddMessage(new ChatHistory.Message(AuthorRole.System, content)); + + /// + /// Add an assistant message to the chat history. + /// + /// + /// + public ChatSession AddAssistantMessage(string content) + => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + + /// + /// Add a user message to the chat history. 
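// Illustrative sketch (not part of this patch): AddMessage above enforces the
// ordering of an optional system message first, then alternating user and
// assistant messages, and the helpers return the session itself, so a history
// can be seeded fluently (the message texts are placeholders):
//
//     session.AddSystemMessage("You are a helpful assistant.")
//            .AddUserMessage("Hello!")
//            .AddAssistantMessage("Hi, how can I help?");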
+ /// + /// + /// + public ChatSession AddUserMessage(string content) + => AddMessage(new ChatHistory.Message(AuthorRole.User, content)); + + /// + /// Remove the last message from the chat history. + /// + /// + public ChatSession RemoveLastMessage() + { + History.Messages.RemoveAt(History.Messages.Count - 1); + return this; + } + + /// + /// Replace a user message with a new message and remove all messages after the new message. + /// This is useful when the user wants to edit a message. And regenerate the response. + /// + /// + /// + /// + public ChatSession ReplaceUserMessage( + ChatHistory.Message oldMessage, + ChatHistory.Message newMessage) + { + if (oldMessage.AuthorRole != AuthorRole.User) + { + throw new ArgumentException("Old message must be a user message", nameof(oldMessage)); + } + + if (newMessage.AuthorRole != AuthorRole.User) { - foreach (var inputTransform in InputTransformPipeline) - prompt = inputTransform.Transform(prompt); + throw new ArgumentException("New message must be a user message", nameof(newMessage)); + } - History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt)); + int index = History.Messages.IndexOf(oldMessage); + if (index == -1) + { + throw new ArgumentException("Old message does not exist in history", nameof(oldMessage)); + } - if (_executor is InteractiveExecutor executor) - { - InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); - prompt = state.IsPromptRun - ? HistoryTransform.HistoryToText(History) - : HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(AuthorRole.User, prompt)); - } + History.Messages[index] = newMessage; + + // Remove all message after the new message + History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1); + + return this; + } - StringBuilder sb = new(); + /// + /// Chat with the model. + /// + /// + /// + /// + /// + /// + /// + public async IAsyncEnumerable ChatAsync( + ChatHistory.Message message, + bool applyInputTransformPipeline, + IInferenceParams? inferenceParams = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // The message must be a user message + if (message.AuthorRole != AuthorRole.User) + { + throw new ArgumentException("Message must be a user message", nameof(message)); + } - await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) + // Apply input transform pipeline + if (applyInputTransformPipeline) + { + foreach (var inputTransform in InputTransformPipeline) { - yield return textToken; - sb.Append(textToken); + message.Content = inputTransform.Transform(message.Content); } + } - string assistantMessage = sb.ToString(); + // Add the user's message to the history + AddUserMessage(message.Content); + + // Prepare prompt variable + string prompt; + + // Check if the session history was restored from a previous session + // or added as part of new chat session history. + InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData(); + + // If "IsPromptRun" is true, the session was newly started. + if (state.IsPromptRun) + { + // If the session history was added as part of new chat session history, + // convert the complete history includsing system message and manually added history + // to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation. 
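// (Clarifying note, not part of the original patch: after the first run the
// executor's state already contains the earlier turns, which is why only the
// newest message needs to be rendered through the history transform in the
// else branch below.)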
+ prompt = HistoryTransform.HistoryToText(History); + } + else + { + // If the session was restored from a previous session, + // convert only the current message to the prompt with the prompt template + // specified in the HistoryTransform class implementation that is provided. + ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content); + prompt = HistoryTransform.HistoryToText(singleMessageHistory); + } + + string assistantMessage = string.Empty; + + await foreach ( + string textToken + in ChatAsyncInternal( + prompt, + inferenceParams, + cancellationToken)) + { + assistantMessage += textToken; + yield return textToken; + } - // Remove end tokens from the assistant message - // if defined in inferenceParams.AntiPrompts. - // We only want the response that was generated and not tokens - // that are delimiting the beginning or end of the response. - if (inferenceParams?.AntiPrompts != null) + // Add the assistant message to the history + AddAssistantMessage(assistantMessage); + } + + /// + /// Chat with the model. + /// + /// + /// + /// + /// + public IAsyncEnumerable ChatAsync( + ChatHistory.Message message, + IInferenceParams? inferenceParams = null, + CancellationToken cancellationToken = default) + { + return ChatAsync( + message, + applyInputTransformPipeline: true, + inferenceParams, + cancellationToken); + } + + /// + /// Chat with the model. + /// + /// + /// + /// + /// + /// + /// + public IAsyncEnumerable ChatAsync( + ChatHistory history, + bool applyInputTransformPipeline, + IInferenceParams? inferenceParams = null, + CancellationToken cancellationToken = default) + { + ChatHistory.Message lastMessage = history.Messages.LastOrDefault() + ?? throw new ArgumentException("History must contain at least one message", nameof(history)); + + foreach ( + ChatHistory.Message message + in history.Messages.Take(history.Messages.Count - 1)) + { + // Apply input transform pipeline + if (applyInputTransformPipeline + && message.AuthorRole == AuthorRole.User) { - foreach (var stopToken in inferenceParams.AntiPrompts) + foreach ( + var inputTransform + in InputTransformPipeline) { - assistantMessage = assistantMessage.Replace(stopToken, "").Trim(); + message.Content = inputTransform.Transform(message.Content); } } - History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage)); + AddMessage(message); } - /// - /// Generates a response for a given chat history. This method does not manage history state for the user. - /// If you want to e.g. truncate the history of a session to fit into the model's context window, - /// use this method and pass the truncated history to it. If you don't need this control, use the other - /// overload of this method that accepts a user prompt instead. - /// - /// - /// - /// - /// Returns generated text of the assistant message. - public async IAsyncEnumerable ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + return ChatAsync( + lastMessage, + applyInputTransformPipeline, + inferenceParams, + cancellationToken); + } + + /// + /// Chat with the model. + /// + /// + /// + /// + /// + public IAsyncEnumerable ChatAsync( + ChatHistory history, + IInferenceParams? inferenceParams = null, + CancellationToken cancellationToken = default) + { + return ChatAsync( + history, + applyInputTransformPipeline: true, + inferenceParams, + cancellationToken); + } + + /// + /// Regenerate the last assistant message. 
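// Illustrative sketch (not part of this patch): RegenerateAssistantMessageAsync
// below drops the last assistant message, re-submits the user message that
// produced it and streams a fresh reply, so typical usage looks like this; the
// inference settings are placeholders borrowed from the examples later in this
// series:
//
//     await foreach (var token in session.RegenerateAssistantMessageAsync(
//         new InferenceParams { Temperature = 0.9f, AntiPrompts = new List<string> { "User:" } }))
//     {
//         Console.Write(token);
//     }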
+ /// + /// + /// + /// + /// + public async IAsyncEnumerable RegenerateAssistantMessageAsync( + InferenceParams? inferenceParams = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // Make sure the last message is an assistant message (reponse from the LLM). + ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault(); + + if (lastAssistantMessage is null + || lastAssistantMessage.AuthorRole != AuthorRole.Assistant) { - if (history.Messages.Count == 0) - { - throw new ArgumentException("History must contain at least one message."); - } + throw new InvalidOperationException("Last message must be an assistant message"); + } - string prompt; - if (_executor is InteractiveExecutor executor) - { - InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); + // Remove the last assistant message from the history. + RemoveLastMessage(); - if (state.IsPromptRun) - { - prompt = HistoryTransform.HistoryToText(History); - } - else - { - ChatHistory.Message lastMessage = history.Messages.Last(); - prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content)); - } - } - else - { - ChatHistory.Message lastMessage = history.Messages.Last(); - prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content)); - } + // Get the last user message. + ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault(); - await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) - { - yield return textToken; - } + if (lastUserMessage is null + || lastUserMessage.AuthorRole != AuthorRole.User) + { + throw new InvalidOperationException("Last message must be a user message"); } - private async IAsyncEnumerable ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + // Remove the last user message from the history. + RemoveLastMessage(); + + // Regenerate the assistant message. + await foreach ( + string textToken + in ChatAsync( + lastUserMessage, + applyInputTransformPipeline: false, + inferenceParams, + cancellationToken)) { - var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); - await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) - { - yield return textToken; - } + yield return textToken; + } + } + + private async IAsyncEnumerable ChatAsyncInternal( + string prompt, + IInferenceParams? inferenceParams = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken); + + await foreach ( + string textToken + in OutputTransform + .TransformAsync(results) + .WithCancellation(cancellationToken)) + { + yield return textToken; } } -} \ No newline at end of file +} diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index 7224b314..3f038874 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -1,4 +1,7 @@ using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using System.Text.Json.Serialization; namespace LLama.Common { @@ -43,11 +46,14 @@ namespace LLama.Common /// /// Role of the message author, e.g. 
user/assistant/system /// + [JsonConverter(typeof(JsonStringEnumConverter))] + [JsonPropertyName("author_role")] public AuthorRole AuthorRole { get; set; } /// /// Message content /// + [JsonPropertyName("content")] public string Content { get; set; } /// @@ -65,15 +71,14 @@ namespace LLama.Common /// /// List of messages in the chat /// - public List Messages { get; } + [JsonPropertyName("messages")] + public List Messages { get; set; } = new(); /// /// Create a new instance of the chat content class /// - public ChatHistory() - { - this.Messages = new List(); - } + [JsonConstructor] + public ChatHistory() { } /// /// Add a message to the chat history @@ -84,6 +89,29 @@ namespace LLama.Common { this.Messages.Add(new Message(authorRole, content)); } - } + /// + /// Serialize the chat history to JSON + /// + /// + public string ToJson() + { + return JsonSerializer.Serialize( + this, + new JsonSerializerOptions() + { + WriteIndented = true + }); + } + + /// + /// Deserialize a chat history from JSON + /// + /// + /// + public static ChatHistory? FromJson(string json) + { + return JsonSerializer.Deserialize(json); + } + } } From 73d17259542b082ce6c49f6adb3a1d892d8449ad Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Sat, 25 Nov 2023 09:29:00 -0600 Subject: [PATCH 07/22] Modified / updated ChatSession examples --- LLama.Examples/Assets/chat-with-bob.json | 24 +++++ .../Examples/ChatSessionStripRoleName.cs | 76 ++++++------- .../Examples/ChatSessionWithHistory.cs | 100 ++++++++++++++++++ .../Examples/ChatSessionWithRoleName.cs | 76 ++++++------- LLama.Examples/Examples/LoadAndSaveSession.cs | 13 ++- LLama.Examples/Examples/Runner.cs | 5 +- LLama.Examples/LLama.Examples.csproj | 3 + 7 files changed, 217 insertions(+), 80 deletions(-) create mode 100644 LLama.Examples/Assets/chat-with-bob.json create mode 100644 LLama.Examples/Examples/ChatSessionWithHistory.cs diff --git a/LLama.Examples/Assets/chat-with-bob.json b/LLama.Examples/Assets/chat-with-bob.json new file mode 100644 index 00000000..52dc3910 --- /dev/null +++ b/LLama.Examples/Assets/chat-with-bob.json @@ -0,0 +1,24 @@ +{ + "messages": [ + { + "author_role": "System", + "content": "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision." + }, + { + "author_role": "User", + "content": "Hello, Bob." + }, + { + "author_role": "Assistant", + "content": "Hello. How may I help you today?" + }, + { + "author_role": "User", + "content": "Please tell me the largest city in Europe." + }, + { + "author_role": "Assistant", + "content": "Sure. The largest city in Europe is Istanbul, Turkey." 
+ } + ] +} diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index 41362c4a..b39ac3ef 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -1,44 +1,44 @@ -using LLama.Common; +// using LLama.Common; -namespace LLama.Examples.Examples -{ - public class ChatSessionStripRoleName - { - public static async Task Run() - { - Console.Write("Please input your model path: "); - var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); +// namespace LLama.Examples.Examples +// { +// public class ChatSessionStripRoleName +// { +// public static async Task Run() +// { +// Console.Write("Please input your model path: "); +// var modelPath = Console.ReadLine(); +// var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - var parameters = new ModelParams(modelPath) - { - ContextSize = 1024, - Seed = 1337, - GpuLayerCount = 5 - }; - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - var executor = new InteractiveExecutor(context); +// var parameters = new ModelParams(modelPath) +// { +// ContextSize = 1024, +// Seed = 1337, +// GpuLayerCount = 5 +// }; +// using var model = LLamaWeights.LoadFromFile(parameters); +// using var context = model.CreateContext(parameters); +// var executor = new InteractiveExecutor(context); - var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); +// var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started. The role names won't be printed."); - Console.ForegroundColor = ConsoleColor.White; +// Console.ForegroundColor = ConsoleColor.Yellow; +// Console.WriteLine("The chat session has started. 
The role names won't be printed."); +// Console.ForegroundColor = ConsoleColor.White; - // show the prompt - Console.Write(prompt); - while (true) - { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) - { - Console.Write(text); - } +// // show the prompt +// Console.Write(prompt); +// while (true) +// { +// await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) +// { +// Console.Write(text); +// } - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; - } - } - } -} +// Console.ForegroundColor = ConsoleColor.Green; +// prompt = Console.ReadLine(); +// Console.ForegroundColor = ConsoleColor.White; +// } +// } +// } +// } diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs new file mode 100644 index 00000000..27f4912b --- /dev/null +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -0,0 +1,100 @@ +using DocumentFormat.OpenXml.Bibliography; +using LLama.Common; + +namespace LLama.Examples.Examples; + +public class ChatSessionWithHistory +{ + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + ChatSession session; + if (Directory.Exists("Assets/chat-with-bob")) + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Loading session from disk."); + Console.ForegroundColor = ConsoleColor.White; + + session = new ChatSession(executor); + session.LoadSession("Assets/chat-with-bob"); + } + else + { + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + + session = new ChatSession(executor, chatHistory); + } + + session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + new string[] { "User:", "Assistant:" }, + redundancyLength: 8)); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? 
""; + + while (userInput != "exit") + { + if (userInput == "save") + { + session.SaveSession("Assets/chat-with-bob"); + // await session.LoadSessionAsync("Assets/chat-with-bob"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session saved."); + } + else if (userInput == "regenerate") + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Regenerating last response ..."); + + await foreach ( + var text + in session.RegenerateAssistantMessageAsync( + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + } + else + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; + } + } +} diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index c9ea9023..e5c180b7 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -1,44 +1,44 @@ -using LLama.Common; +// using LLama.Common; -namespace LLama.Examples.Examples -{ - public class ChatSessionWithRoleName - { - public static async Task Run() - { - Console.Write("Please input your model path: "); - var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); +// namespace LLama.Examples.Examples +// { +// public class ChatSessionWithRoleName +// { +// public static async Task Run() +// { +// Console.Write("Please input your model path: "); +// var modelPath = Console.ReadLine(); +// var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - var parameters = new ModelParams(modelPath) - { - ContextSize = 1024, - Seed = 1337, - GpuLayerCount = 5 - }; - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - var executor = new InteractiveExecutor(context); +// var parameters = new ModelParams(modelPath) +// { +// ContextSize = 1024, +// Seed = 1337, +// GpuLayerCount = 5 +// }; +// using var model = LLamaWeights.LoadFromFile(parameters); +// using var context = model.CreateContext(parameters); +// var executor = new InteractiveExecutor(context); - var session = new ChatSession(executor); +// var session = new ChatSession(executor); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); - Console.ForegroundColor = ConsoleColor.White; +// Console.ForegroundColor = ConsoleColor.Yellow; +// Console.WriteLine("The chat session has started. 
In this example, the prompt is printed for better visual result."); +// Console.ForegroundColor = ConsoleColor.White; - // show the prompt - Console.Write(prompt); - while (true) - { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) - { - Console.Write(text); - } +// // show the prompt +// Console.Write(prompt); +// while (true) +// { +// await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) +// { +// Console.Write(text); +// } - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; - } - } - } -} +// Console.ForegroundColor = ConsoleColor.Green; +// prompt = Console.ReadLine(); +// Console.ForegroundColor = ConsoleColor.White; +// } +// } +// } +// } diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs index 91068091..678d3eb9 100644 --- a/LLama.Examples/Examples/LoadAndSaveSession.cs +++ b/LLama.Examples/Examples/LoadAndSaveSession.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using DocumentFormat.OpenXml.Bibliography; +using LLama.Common; namespace LLama.Examples.Examples { @@ -30,7 +31,15 @@ namespace LLama.Examples.Examples Console.Write(prompt); while (true) { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, prompt), + new InferenceParams() + { + Temperature = 0.6f, + AntiPrompts = new List { "User:" } + })) { Console.Write(text); } diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs index aca0a7da..43c12f87 100644 --- a/LLama.Examples/Examples/Runner.cs +++ b/LLama.Examples/Examples/Runner.cs @@ -6,8 +6,9 @@ public class Runner { private static readonly Dictionary> Examples = new() { - { "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run }, - { "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run }, + { "Run a chat session with history.", ChatSessionWithHistory.Run }, + // { "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run }, + // { "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run }, { "Interactive mode chat by using executor.", InteractiveModeExecute.Run }, { "Instruct mode chat by using executor.", InstructModeExecute.Run }, { "Stateless mode chat by using executor.", StatelessModeExecute.Run }, diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index 5053c038..c2491218 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -41,6 +41,9 @@ + + PreserveNewest + PreserveNewest From 422605d98063ef5da2cef8f33819b2d35c5c43df Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Mon, 27 Nov 2023 08:26:52 -0600 Subject: [PATCH 08/22] Re-add ChatSession examples --- .../Examples/ChatSessionStripRoleName.cs | 105 ++++++++++-------- .../Examples/ChatSessionWithHistory.cs | 2 - .../Examples/ChatSessionWithRoleName.cs | 102 +++++++++-------- LLama.Examples/Examples/Runner.cs | 4 +- 4 files changed, 121 insertions(+), 92 deletions(-) diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index b39ac3ef..1246db59 100644 --- 
a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -1,44 +1,61 @@ -// using LLama.Common; - -// namespace LLama.Examples.Examples -// { -// public class ChatSessionStripRoleName -// { -// public static async Task Run() -// { -// Console.Write("Please input your model path: "); -// var modelPath = Console.ReadLine(); -// var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - -// var parameters = new ModelParams(modelPath) -// { -// ContextSize = 1024, -// Seed = 1337, -// GpuLayerCount = 5 -// }; -// using var model = LLamaWeights.LoadFromFile(parameters); -// using var context = model.CreateContext(parameters); -// var executor = new InteractiveExecutor(context); - -// var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); - -// Console.ForegroundColor = ConsoleColor.Yellow; -// Console.WriteLine("The chat session has started. The role names won't be printed."); -// Console.ForegroundColor = ConsoleColor.White; - -// // show the prompt -// Console.Write(prompt); -// while (true) -// { -// await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) -// { -// Console.Write(text); -// } - -// Console.ForegroundColor = ConsoleColor.Green; -// prompt = Console.ReadLine(); -// Console.ForegroundColor = ConsoleColor.White; -// } -// } -// } -// } +using LLama.Common; + +namespace LLama.Examples.Examples; + +public class ChatSessionStripRoleName +{ + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + + ChatSession session = new(executor, chatHistory); + session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + new string[] { "User:", "Assistant:" }, + redundancyLength: 8)); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? ""; + + while (userInput != "exit") + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? 
""; + + Console.ForegroundColor = ConsoleColor.White; + } + } +} diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 27f4912b..98ba7d75 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -1,4 +1,3 @@ -using DocumentFormat.OpenXml.Bibliography; using LLama.Common; namespace LLama.Examples.Examples; @@ -60,7 +59,6 @@ public class ChatSessionWithHistory if (userInput == "save") { session.SaveSession("Assets/chat-with-bob"); - // await session.LoadSessionAsync("Assets/chat-with-bob"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index e5c180b7..d6b0d98e 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -1,44 +1,58 @@ -// using LLama.Common; - -// namespace LLama.Examples.Examples -// { -// public class ChatSessionWithRoleName -// { -// public static async Task Run() -// { -// Console.Write("Please input your model path: "); -// var modelPath = Console.ReadLine(); -// var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - -// var parameters = new ModelParams(modelPath) -// { -// ContextSize = 1024, -// Seed = 1337, -// GpuLayerCount = 5 -// }; -// using var model = LLamaWeights.LoadFromFile(parameters); -// using var context = model.CreateContext(parameters); -// var executor = new InteractiveExecutor(context); - -// var session = new ChatSession(executor); - -// Console.ForegroundColor = ConsoleColor.Yellow; -// Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); -// Console.ForegroundColor = ConsoleColor.White; - -// // show the prompt -// Console.Write(prompt); -// while (true) -// { -// await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) -// { -// Console.Write(text); -// } - -// Console.ForegroundColor = ConsoleColor.Green; -// prompt = Console.ReadLine(); -// Console.ForegroundColor = ConsoleColor.White; -// } -// } -// } -// } +using LLama.Common; + +namespace LLama.Examples.Examples; + +public class ChatSessionWithRoleName +{ + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + + ChatSession session = new(executor, chatHistory); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? 
""; + + while (userInput != "exit") + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; + } + } +} diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs index 43c12f87..d7653657 100644 --- a/LLama.Examples/Examples/Runner.cs +++ b/LLama.Examples/Examples/Runner.cs @@ -7,8 +7,8 @@ public class Runner private static readonly Dictionary> Examples = new() { { "Run a chat session with history.", ChatSessionWithHistory.Run }, - // { "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run }, - // { "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run }, + { "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run }, + { "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run }, { "Interactive mode chat by using executor.", InteractiveModeExecute.Run }, { "Instruct mode chat by using executor.", InstructModeExecute.Run }, { "Stateless mode chat by using executor.", StatelessModeExecute.Run }, From b34f72a883a8851cafb6fb6e3ebca9fa2c0e3a29 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 8 Dec 2023 01:02:27 +0000 Subject: [PATCH 09/22] - Added `SamplingPipeline` to inference params which overrides all other options with an entirely custom pipeline. - Added a `Sample` method to `LLamaContext` which uses a custom pipeline - Modified all executors to use the custom pipeline if it exists --- LLama.Web/Common/InferenceOptions.cs | 10 ++++-- LLama/Abstractions/IInferenceParams.cs | 6 ++++ LLama/Common/InferenceParams.cs | 4 +++ LLama/LLamaContext.cs | 12 +++++++ LLama/LLamaInstructExecutor.cs | 26 ++++++++++------ LLama/LLamaInteractExecutor.cs | 28 +++++++++++------ LLama/LLamaStatelessExecutor.cs | 29 +++++++++++------ LLama/Sampling/ISamplingPipeline.cs | 43 +++++++++++++++++++++++--- 8 files changed, 123 insertions(+), 35 deletions(-) diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs index 89d94ade..c604dc0d 100644 --- a/LLama.Web/Common/InferenceOptions.cs +++ b/LLama.Web/Common/InferenceOptions.cs @@ -1,6 +1,9 @@ -using LLama.Common; +#nullable enable + +using LLama.Common; using LLama.Abstractions; using LLama.Native; +using LLama.Sampling; namespace LLama.Web.Common { @@ -64,6 +67,9 @@ namespace LLama.Web.Common /// /// A grammar to constrain possible tokens /// - public SafeLLamaGrammarHandle Grammar { get; set; } = null; + public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + public ISamplingPipeline? SamplingPipeline { get; set; } } } diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index d87faf0e..e1e89414 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using LLama.Common; using LLama.Native; +using LLama.Sampling; namespace LLama.Abstractions { @@ -108,5 +109,10 @@ namespace LLama.Abstractions /// Grammar to constrain possible tokens /// SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + /// Set a custom sampling pipeline to use. If this is set All other sampling parameters are ignored! + /// + ISamplingPipeline? 
SamplingPipeline { get; set; } } } \ No newline at end of file diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index d7bd19d9..c1f39550 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using LLama.Native; +using LLama.Sampling; namespace LLama.Common { @@ -76,6 +77,9 @@ namespace LLama.Common /// public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + public ISamplingPipeline? SamplingPipeline { get; set; } } /// diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 3a3e51af..2902dc8f 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -10,6 +10,7 @@ using LLama.Common; using System.Runtime.InteropServices; using LLama.Extensions; using LLama.Abstractions; +using LLama.Sampling; using Microsoft.Extensions.Logging; namespace LLama @@ -212,6 +213,17 @@ namespace LLama } } + /// + /// Sample a single token from this context, using the given sampling pipeline + /// + /// The pipeline to use to process the logits and to select a token + /// The tokens recently returned from the model + /// The selected token + public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan lastTokens) + { + return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); + } + /// /// Perform the sampling. Please don't use it unless you fully know what it does. /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index d81630aa..3ed66890 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -210,16 +210,24 @@ namespace LLama SaveSessionFile(_pathSession); } - var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + } + else + { + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var mu = MirostatMu; - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; + var mu = MirostatMu; + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + MirostatMu = mu; + } _last_n_tokens.Enqueue(id); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 4d28274b..9cecf437 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -189,16 +189,24 @@ namespace LLama SaveSessionFile(_pathSession); } - var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, 
inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - var mu = MirostatMu; - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + } + else + { + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + + var mu = MirostatMu; + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + MirostatMu = mu; + } _last_n_tokens.Enqueue(id); diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 9c41af7c..831aceb2 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using LLama.Native; +using LLama.Sampling; using Microsoft.Extensions.Logging; namespace LLama @@ -85,16 +86,24 @@ namespace LLama var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) { - // Penalize the generated tokens by various penalties - var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - // Sample a single token - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens); + } + else + { + // Penalize the generated tokens by various penalties + var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + + // Sample a single token + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + } // Decode this token into text decoder.Add(id); diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index 489f2c5a..4540e9fc 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ 
b/LLama/Sampling/ISamplingPipeline.cs @@ -1,5 +1,7 @@ using System; +using System.Buffers; using System.Collections.Generic; +using System.Runtime.InteropServices; using LLama.Native; using LLama.Sampling.Logits; using LLama.Sampling.Selection; @@ -16,9 +18,9 @@ public interface ISamplingPipeline /// /// Sample a single token from the given logits /// - /// - /// - /// + /// The context being sampled from + /// The logits produced by the model + /// A span of tokens recently returned by the model /// int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); @@ -28,10 +30,43 @@ public interface ISamplingPipeline void Reset(); } +/// +/// Extensions methods for ISamplingPipeline +/// +public static class ISamplingPipelineExtensions +{ + /// + /// Sample a single token from the given logits + /// + /// + /// The context being sampled from + /// The logits produced by the model + /// A list of tokens recently returned by the model + /// + public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span logits, List lastTokens) + { +#if NET5_0_OR_GREATER + var span = CollectionsMarshal.AsSpan(lastTokens); + return pipeline.Sample(ctx, logits, span); +#else + var copy = ArrayPool.Shared.Rent(lastTokens.Count); + try + { + lastTokens.CopyTo(copy); + return pipeline.Sample(ctx, logits, copy.AsSpan(0, copy.Length)); + } + finally + { + ArrayPool.Shared.Return(copy); + } +#endif + } +} + /// /// Simple implementation of `ISamplingPipeline`, applies processors in order every time /// -public sealed class BasicSamplingPipeline +public sealed class ConfigurableSamplingPipeline : ISamplingPipeline { /// From 3afc007499866f5b47f98993d46a6fcb5b4f8fd2 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 8 Dec 2023 01:17:24 +0000 Subject: [PATCH 10/22] - Added "protected" logits, instead of the awkward save/load mechanism - Added an example usage to one of the tests --- LLama.Unittest/StatelessExecutorTest.cs | 37 ++++++++- LLama/Sampling/ISamplingPipeline.cs | 33 +++++++- LLama/Sampling/Logits/SaveLoad.cs | 100 ------------------------ 3 files changed, 66 insertions(+), 104 deletions(-) delete mode 100644 LLama/Sampling/Logits/SaveLoad.cs diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 195cc4a2..d847e787 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,5 +1,9 @@ using System.Diagnostics; using LLama.Common; +using LLama.Sampling; +using LLama.Sampling.Logits; +using LLama.Sampling.Selection; +using LLama.Sampling.Tokens; using Xunit.Abstractions; namespace LLama.Unittest @@ -30,10 +34,41 @@ namespace LLama.Unittest [Fact] public async Task Stateless() { + // Create a custom pipeline that mimics the default pipeline + var pipeline = new ConfigurableSamplingPipeline() + { + ProtectedLogits = + { + _weights.NewlineToken, + _weights.BeginningOfSentenceToken, + _weights.EndOfSentenceToken + }, + LogitProcessors = + { + new LogitBias + { + Biases = + { + { _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing! + { 42, 0f }, + } + } + }, + TokenDataProcessors = + { + new TailFreeSampling { Z = 1 }, + new LocallyTypicalSampling { P = 1 }, + new TopPSampling { P = 0.95f }, + new MinPSampling { P = 0.05f }, + new TemperatureSampling { Temperature = 0.8f }, + }, + Selector = new StandardSelection(), + }; + var executor = new StatelessExecutor(_weights, _params); const string question = "Question. 
what is a cat?\nAnswer: "; - var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline}; var timer = new Stopwatch(); timer.Start(); diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index 4540e9fc..3b829ed4 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -74,6 +74,11 @@ public sealed class ConfigurableSamplingPipeline /// public IList LogitProcessors { get; } = new List(); + /// + /// Logits values which will not be changed by the logit processors + /// + public IList ProtectedLogits { get; } = new List(); + /// /// Token data processors to apply in this pipeline /// @@ -87,9 +92,31 @@ public sealed class ConfigurableSamplingPipeline /// public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { - // Modify raw logits - foreach (var logitProcessor in LogitProcessors) - logitProcessor.ProcessLogits(ctx, logits, lastTokens); + var savedLogitsCount = ProtectedLogits.Count; + var savedLogitValues = ArrayPool.Shared.Rent(savedLogitsCount); + var savedLogitIndices = ArrayPool.Shared.Rent(savedLogitsCount); + try + { + // Save the values of protected logits + for (var i = 0; i < ProtectedLogits.Count; i++) + { + savedLogitValues[i] = logits[ProtectedLogits[i]]; + savedLogitIndices[i] = ProtectedLogits[i]; + } + + // Modify raw logits + foreach (var logitProcessor in LogitProcessors) + logitProcessor.ProcessLogits(ctx, logits, lastTokens); + + // Restore the values of protected logits + for (var i = 0; i < savedLogitsCount; i++) + logits[savedLogitIndices[i]] = savedLogitValues[i]; + } + finally + { + ArrayPool.Shared.Return(savedLogitValues); + ArrayPool.Shared.Return(savedLogitIndices); + } // Convert logits into token candidates var candidates_p = LLamaTokenDataArray.Create(logits); diff --git a/LLama/Sampling/Logits/SaveLoad.cs b/LLama/Sampling/Logits/SaveLoad.cs deleted file mode 100644 index 6f80aec4..00000000 --- a/LLama/Sampling/Logits/SaveLoad.cs +++ /dev/null @@ -1,100 +0,0 @@ -using System; -using System.Collections.Generic; -using LLama.Native; - -namespace LLama.Sampling.Logits; - -/// -/// Save certain logit values -/// -public sealed class SaveLogitValues - : ILogitProcessor -{ - private readonly Dictionary _saved = new(); - - /// - /// Logits to save - /// - public ISet Logits { get; } = new HashSet(); - - /// - /// Saved logit values - /// - public IReadOnlyDictionary Values => _saved; - - /// - public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - _saved.Clear(); - foreach (var logit in Logits) - _saved[logit] = logits[logit]; - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - _saved.Clear(); - } - - /// - public void Dispose() - { - } - - /// - /// Get a logit processor that overwrite the logit values with the values saved here - /// - /// - public ILogitProcessor GetWriter() - { - return new LoadLogitValues(_saved); - } -} - -/// -/// Overwrite certain logit values -/// -public sealed class LoadLogitValues - : ILogitProcessor -{ - /// - /// Logits to overwrite, token -> logit - /// - public IDictionary Values { get; } - - /// - /// Create a new LoadLogitValues - /// - /// Source for values to overwrite - public LoadLogitValues(Dictionary? values = null) - { - Values = values ?? 
new Dictionary(); - } - - /// - public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - foreach (var logit in Values) - logits[logit.Key] = logit.Value; - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file From 835958398cc6c5948036f269796810f20bf6657a Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 8 Dec 2023 16:25:13 +0000 Subject: [PATCH 11/22] - Removed the object wrappers and configurable pipeline, they can be better written in code. - Added BaseSamplingPipeline which provides a base impl of `ISamplingPipeline` - Added `DefaultSamplingPipeline` which mimics normal llama.cpp sampling --- LLama.Unittest/GrammarParserTest.cs | 3 +- LLama.Unittest/StatelessExecutorTest.cs | 35 +--- LLama/Native/LLamaTokenDataArray.cs | 29 +++- LLama/Sampling/BaseSamplingPipeline.cs | 128 +++++++++++++++ LLama/Sampling/DefaultSamplingPipeline.cs | 149 ++++++++++++++++++ LLama/Sampling/ISamplingPipeline.cs | 102 +----------- LLama/Sampling/Logits/ILogitProcessor.cs | 34 ---- LLama/Sampling/Logits/LogitBias.cs | 39 ----- LLama/Sampling/Selection/GreedySelection.cs | 27 ---- LLama/Sampling/Selection/ITokenSelector.cs | 25 --- .../Sampling/Selection/Mirostat2Selection.cs | 65 -------- LLama/Sampling/Selection/MirostatSelection.cs | 76 --------- LLama/Sampling/Selection/StandardSelection.cs | 27 ---- LLama/Sampling/Tokens/GrammarSampling.cs | 59 ------- LLama/Sampling/Tokens/ITokenDataProcessor.cs | 34 ---- .../Sampling/Tokens/LocallyTypicalSampling.cs | 42 ----- LLama/Sampling/Tokens/MinPSampling.cs | 42 ----- LLama/Sampling/Tokens/RepetitionPenalty.cs | 77 --------- LLama/Sampling/Tokens/TailFreeSampling.cs | 42 ----- LLama/Sampling/Tokens/TemperatureSampling.cs | 38 ----- LLama/Sampling/Tokens/TopKSampling.cs | 38 ----- LLama/Sampling/Tokens/TopPSampling.cs | 42 ----- 22 files changed, 309 insertions(+), 844 deletions(-) create mode 100644 LLama/Sampling/BaseSamplingPipeline.cs create mode 100644 LLama/Sampling/DefaultSamplingPipeline.cs delete mode 100644 LLama/Sampling/Logits/ILogitProcessor.cs delete mode 100644 LLama/Sampling/Logits/LogitBias.cs delete mode 100644 LLama/Sampling/Selection/GreedySelection.cs delete mode 100644 LLama/Sampling/Selection/ITokenSelector.cs delete mode 100644 LLama/Sampling/Selection/Mirostat2Selection.cs delete mode 100644 LLama/Sampling/Selection/MirostatSelection.cs delete mode 100644 LLama/Sampling/Selection/StandardSelection.cs delete mode 100644 LLama/Sampling/Tokens/GrammarSampling.cs delete mode 100644 LLama/Sampling/Tokens/ITokenDataProcessor.cs delete mode 100644 LLama/Sampling/Tokens/LocallyTypicalSampling.cs delete mode 100644 LLama/Sampling/Tokens/MinPSampling.cs delete mode 100644 LLama/Sampling/Tokens/RepetitionPenalty.cs delete mode 100644 LLama/Sampling/Tokens/TailFreeSampling.cs delete mode 100644 LLama/Sampling/Tokens/TemperatureSampling.cs delete mode 100644 LLama/Sampling/Tokens/TopKSampling.cs delete mode 100644 LLama/Sampling/Tokens/TopPSampling.cs diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs index 9ad77531..389563aa 100644 --- a/LLama.Unittest/GrammarParserTest.cs +++ b/LLama.Unittest/GrammarParserTest.cs @@ -1,5 +1,4 @@ -using System.Text; -using LLama.Exceptions; +using LLama.Exceptions; using LLama.Native; using LLama.Grammars; diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs 
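As a rough standalone sketch of how the new pipeline is intended to be plugged in (the model path and sampling values below are placeholders, not taken from this patch; the unit-test diff that follows exercises the same API):

using System;
using LLama;
using LLama.Common;
using LLama.Sampling;

// Load a model and create a stateless executor (any executor works the same way).
var parameters = new ModelParams("path/to/model.gguf") { ContextSize = 1024 };
using var weights = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(weights, parameters);

// When SamplingPipeline is set it overrides all of the individual sampling parameters.
var inferenceParams = new InferenceParams
{
    MaxTokens = 64,
    SamplingPipeline = new DefaultSamplingPipeline
    {
        Temperature = 0.8f,
        TopK = 40,
        RepeatPenalty = 1.1f,
        PenalizeNewline = false,
    },
};

await foreach (var token in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
    Console.Write(token);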
index d847e787..72e9acf8 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,9 +1,6 @@ using System.Diagnostics; using LLama.Common; using LLama.Sampling; -using LLama.Sampling.Logits; -using LLama.Sampling.Selection; -using LLama.Sampling.Tokens; using Xunit.Abstractions; namespace LLama.Unittest @@ -35,40 +32,12 @@ namespace LLama.Unittest public async Task Stateless() { // Create a custom pipeline that mimics the default pipeline - var pipeline = new ConfigurableSamplingPipeline() - { - ProtectedLogits = - { - _weights.NewlineToken, - _weights.BeginningOfSentenceToken, - _weights.EndOfSentenceToken - }, - LogitProcessors = - { - new LogitBias - { - Biases = - { - { _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing! - { 42, 0f }, - } - } - }, - TokenDataProcessors = - { - new TailFreeSampling { Z = 1 }, - new LocallyTypicalSampling { P = 1 }, - new TopPSampling { P = 0.95f }, - new MinPSampling { P = 0.05f }, - new TemperatureSampling { Temperature = 0.8f }, - }, - Selector = new StandardSelection(), - }; + var pipeline = new DefaultSamplingPipeline(); var executor = new StatelessExecutor(_weights, _params); const string question = "Question. what is a cat?\nAnswer: "; - var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline}; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline }; var timer = new Stopwatch(); timer.Start(); diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 897cf8b8..5059a5f3 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -46,14 +46,41 @@ namespace LLama.Native return new LLamaTokenDataArray(candidates); } + /// + /// Overwrite the logit values for all given tokens + /// + /// tuples of token and logit value to overwrite + public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values) + { + if (values.Length == 0) + return; + + var dataSpan = data.Span; + foreach (var (token, value) in values) + { + for (var i = 0; i < data.Length; i++) + { + if (dataSpan[i].id == token) + { + dataSpan[i].logit = value; + break; + } + } + } + sorted = false; + } + #region sampling /// /// Apply grammar rules to candidate tokens /// /// /// - public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar) + public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar) { + if (grammar == null) + return; + using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_grammar(ctx, ref st, grammar); diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs new file mode 100644 index 00000000..4c0f7689 --- /dev/null +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -0,0 +1,128 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Sampling; + +/// +/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`. +/// +public abstract class BaseSamplingPipeline + : ISamplingPipeline +{ + private int _savedLogitsCount; + private (int index, float logit)[]? 
_savedLogits; + + /// + public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + var protectedLogits = GetProtectedTokens(ctx); + _savedLogitsCount = protectedLogits.Count; + _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount); + try + { + // Save the values of protected logits + for (var i = 0; i < protectedLogits.Count; i++) + { + var index = protectedLogits[i]; + var value = logits[index]; + _savedLogits[i] = (index, value); + } + + // Process raw logits + ProcessLogits(ctx, logits, lastTokens); + + // Automatically restore saved logit values after processing + RestoreProtectedTokens(logits); + + // Convert logits into token candidates + var candidates = LLamaTokenDataArray.Create(logits); + + // Process token data array + ProcessTokenDataArray(ctx, candidates, lastTokens); + + // Choose the final value + return ChooseToken(ctx, candidates); + } + finally + { + ArrayPool<(int, float)>.Shared.Return(_savedLogits); + _savedLogits = null; + _savedLogitsCount = 0; + } + } + + #region protected tokens + /// + /// Get all of the "protected" tokens that cannot be changed by ProcessLogits + /// + /// + protected abstract IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx); + + /// + /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits + /// + /// + protected void RestoreProtectedTokens(Span logits) + { + if (_savedLogits == null) + return; + + // The array may be bigger than necessary, get a span of the valid bit + var saved = _savedLogits.AsSpan(0, _savedLogitsCount); + + // Restore the values of protected logits + for (var i = 0; i < saved.Length; i++) + logits[saved[i].index] = saved[i].logit; + } + + /// + /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits + /// + /// + protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) + { + if (_savedLogits == null || _savedLogits.Length == 0) + return; + + candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount)); + } + #endregion + + /// + /// Process the raw logit values + /// + /// The context being sampled from + /// The logits produced by the model + /// A list of tokens recently returned by the model + protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + + /// + /// Process the LLamaTokenDataArray and select a single token + /// + /// The context being sampled from + /// The LLamaTokenDataArray data produced by the model + /// A list of tokens recently returned by the model + /// + protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); + + /// + /// Choose the final token from the candidates + /// + /// + /// + /// + protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates); + + /// + public virtual void Reset() + { + } + + /// + public virtual void Dispose() + { + GC.SuppressFinalize(this); + } +} \ No newline at end of file diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs new file mode 100644 index 00000000..e6db2efe --- /dev/null +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -0,0 +1,149 @@ +using System; +using System.Collections.Generic; +using LLama.Extensions; +using LLama.Native; + +namespace LLama.Sampling; + +/// +/// An implementation of ISamplePipeline which mimics the default llama.cpp sampling +/// +public sealed class 
DefaultSamplingPipeline + : BaseSamplingPipeline +{ + /// + /// Bias values to add to certain logits + /// + public Dictionary LogitBias { get; } = new(); + + /// + /// Grammar to constrain valid tokens + /// + public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 + /// + public float RepeatPenalty { get; set; } = 1.1f; + + /// + /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text + /// so far, decreasing the model's likelihood to repeat the same line verbatim. + ///
+ public float AlphaFrequency + { + get => _alphaFreq; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaFreq = value; + } + } + private float _alphaFreq = 0.1f; + + /// + /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the + /// text so far, increasing the model's likelihood to talk about new topics. + ///
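    // Sketch of how the frequency and presence penalties described above combine
    // (modelled on llama.cpp's llama_sample_repetition_penalties; not literal library
    // code). For a candidate token with `count` occurrences among the recent tokens:
    //
    //     logit -= count * AlphaFrequency + (count > 0 ? 1 : 0) * AlphaPresence;
    //
    // AlphaFrequency therefore scales with how often a token has already been emitted,
    // while AlphaPresence is a flat penalty applied once the token has appeared at all.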
+ public float AlphaPresence + { + get => _alphaPresence; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaPresence = value; + } + } + private float _alphaPresence = 0.1f; + + /// + /// Temperature to apply (higher temperature is more "creative") + /// + public float Temperature { get; set; } = 0.75f; + + /// + /// Number of tokens to keep in TopK sampling + /// + public int TopK { get; set; } + + /// + /// Z value for tail free sampling + /// + public float TailFreeZ { get; set; } + + /// + /// P value for locally typical sampling + /// + public float TypicalP { get; set; } + + /// + /// P value for TopP sampling + /// + public float TopP { get; set; } = 1f; + + /// + /// P value for MinP sampling + /// + public float MinP { get; set; } + + /// + /// Whether the newline value should be protected from being modified by logit bias and repeat penalty + /// + public bool PenalizeNewline { get; set; } = false; + + private readonly int[] _newlineToken = new int[1]; + + /// + protected override IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx) + { + if (PenalizeNewline) + return Array.Empty(); + + _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle); + return _newlineToken; + } + + /// + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + foreach (var (key, value) in LogitBias) + logits[key] += value; + } + + /// + protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + // Apply penalties to candidates + candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); + + // Restore protected tokens, so they are not affected by repetition penalties + RestoreProtectedTokens(candidates); + + // Apply the normal llama.cpp pipeline + candidates.ApplyGrammar(ctx, Grammar); + candidates.TopK(ctx, TopK); + candidates.TailFree(ctx, TailFreeZ); + candidates.LocallyTypical(ctx, TypicalP); + candidates.TopP(ctx, TopP); + candidates.MinP(ctx, MinP); + candidates.Temperature(ctx, Temperature); + var id = candidates.SampleToken(ctx); + + Grammar?.AcceptToken(ctx, id); + return id; + } + + /// + protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) + { + return candidates.SampleToken(ctx); + } +} \ No newline at end of file diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index 3b829ed4..f39bf996 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -3,14 +3,11 @@ using System.Buffers; using System.Collections.Generic; using System.Runtime.InteropServices; using LLama.Native; -using LLama.Sampling.Logits; -using LLama.Sampling.Selection; -using LLama.Sampling.Tokens; namespace LLama.Sampling; /// -/// Convert a span of logits into a single sampled token +/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process. 
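A minimal sketch of a fully custom pipeline, assuming only the Sample/Reset/Dispose shape introduced here (greedy selection, no penalties, no state):

using System;
using LLama.Native;

namespace LLama.Sampling;

// Hypothetical example pipeline, not part of this patch: always picks the most likely token.
public sealed class GreedySamplingPipeline
    : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Convert raw logits into token candidates, then take the argmax
        var candidates = LLamaTokenDataArray.Create(logits);
        return candidates.SampleTokenGreedy(ctx);
    }

    public void Reset()
    {
        // Stateless, nothing to reset
    }

    public void Dispose()
    {
    }
}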
/// public interface ISamplingPipeline : IDisposable @@ -61,101 +58,4 @@ public static class ISamplingPipelineExtensions } #endif } -} - -/// -/// Simple implementation of `ISamplingPipeline`, applies processors in order every time -/// -public sealed class ConfigurableSamplingPipeline - : ISamplingPipeline -{ - /// - /// Logit processors to apply in this pipeline - /// - public IList LogitProcessors { get; } = new List(); - - /// - /// Logits values which will not be changed by the logit processors - /// - public IList ProtectedLogits { get; } = new List(); - - /// - /// Token data processors to apply in this pipeline - /// - public IList TokenDataProcessors { get; } = new List(); - - /// - /// The selector to choose the final token - /// - public ITokenSelector Selector { get; set; } = new StandardSelection(); - - /// - public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - var savedLogitsCount = ProtectedLogits.Count; - var savedLogitValues = ArrayPool.Shared.Rent(savedLogitsCount); - var savedLogitIndices = ArrayPool.Shared.Rent(savedLogitsCount); - try - { - // Save the values of protected logits - for (var i = 0; i < ProtectedLogits.Count; i++) - { - savedLogitValues[i] = logits[ProtectedLogits[i]]; - savedLogitIndices[i] = ProtectedLogits[i]; - } - - // Modify raw logits - foreach (var logitProcessor in LogitProcessors) - logitProcessor.ProcessLogits(ctx, logits, lastTokens); - - // Restore the values of protected logits - for (var i = 0; i < savedLogitsCount; i++) - logits[savedLogitIndices[i]] = savedLogitValues[i]; - } - finally - { - ArrayPool.Shared.Return(savedLogitValues); - ArrayPool.Shared.Return(savedLogitIndices); - } - - // Convert logits into token candidates - var candidates_p = LLamaTokenDataArray.Create(logits); - - // Process token candidates - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.ProcessTokens(ctx, candidates_p, lastTokens); - - // Select a token - var token = Selector.Select(ctx, candidates_p, lastTokens); - - // Tell processors what was selected - foreach (var logitProcessor in LogitProcessors) - logitProcessor.AcceptToken(ctx, token); - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.AcceptToken(ctx, token); - - return token; - } - - /// - public void Reset() - { - foreach (var logitProcessor in LogitProcessors) - logitProcessor.Reset(); - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.Reset(); - - Selector.Reset(); - } - - /// - public void Dispose() - { - foreach (var logitProcessor in LogitProcessors) - logitProcessor.Dispose(); - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.Dispose(); - - Selector.Dispose(); - } } \ No newline at end of file diff --git a/LLama/Sampling/Logits/ILogitProcessor.cs b/LLama/Sampling/Logits/ILogitProcessor.cs deleted file mode 100644 index 76968499..00000000 --- a/LLama/Sampling/Logits/ILogitProcessor.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Logits; - -using llama_token = Int32; - -/// -/// Processes raw logits before sampling, applying penalties to certain tokens -/// -public interface ILogitProcessor - : IDisposable -{ - /// - /// Process raw logits, indexed by llama_token - /// - /// The context this is operating in - /// The token data array to process - /// The most recent tokens output - /// LLamaTokenDataArray, created from logits - void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan 
lastTokens); - - /// - /// Inform this process when a token is accepted by the model - /// - /// - /// - void AcceptToken(SafeLLamaContextHandle ctx, int token); - - /// - /// Reset all internal sampling state - /// - void Reset(); -} \ No newline at end of file diff --git a/LLama/Sampling/Logits/LogitBias.cs b/LLama/Sampling/Logits/LogitBias.cs deleted file mode 100644 index fc821508..00000000 --- a/LLama/Sampling/Logits/LogitBias.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System; -using System.Collections.Generic; -using LLama.Native; - -namespace LLama.Sampling.Logits; - -/// -/// Add a bias directly to logit values -/// -public sealed class LogitBias - : ILogitProcessor -{ - /// - /// Biases to apply, token -> bias - /// - public IDictionary Biases { get; } = new Dictionary(); - - /// - public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - foreach (var kvp in Biases) - logits[kvp.Key] += kvp.Value; - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/GreedySelection.cs b/LLama/Sampling/Selection/GreedySelection.cs deleted file mode 100644 index 30b72456..00000000 --- a/LLama/Sampling/Selection/GreedySelection.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select the most likely token -/// -public sealed class GreedySelection - : ITokenSelector -{ - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleTokenGreedy(ctx); - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/ITokenSelector.cs b/LLama/Sampling/Selection/ITokenSelector.cs deleted file mode 100644 index c8915a92..00000000 --- a/LLama/Sampling/Selection/ITokenSelector.cs +++ /dev/null @@ -1,25 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select a single token from a set of possibilities -/// -public interface ITokenSelector - : IDisposable -{ - /// - /// Select a single token - /// - /// - /// - /// - /// - int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); - - /// - /// Reset the state - /// - void Reset(); -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/Mirostat2Selection.cs b/LLama/Sampling/Selection/Mirostat2Selection.cs deleted file mode 100644 index cdc802c1..00000000 --- a/LLama/Sampling/Selection/Mirostat2Selection.cs +++ /dev/null @@ -1,65 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select a token using Mirostat sampling. -/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. -/// -public sealed class Mirostat2Selection - : ITokenSelector -{ - private float _mu; - - /// - /// Current value of Mu, updated based on the difference between target surprise and actual surprise - /// - public float Mu - { - get => _mu; - set => _mu = value; - } - - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. 
- /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// - public float Tau { get; set; } - - /// - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// - public float Eta { get; set; } - - /// - /// Create a new Mirostat 2.0 sampler - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. - /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - public Mirostat2Selection(float tau, float eta) - { - Tau = tau; - Eta = eta; - } - - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu); - } - - /// - public void Reset() - { - _mu = 2 * Tau; - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/MirostatSelection.cs b/LLama/Sampling/Selection/MirostatSelection.cs deleted file mode 100644 index 5ec34a13..00000000 --- a/LLama/Sampling/Selection/MirostatSelection.cs +++ /dev/null @@ -1,76 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select a token using Mirostat sampling. -/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. -/// -public sealed class MirostatSelection - : ITokenSelector -{ - private float _mu; - - /// - /// Current value of Mu, updated based on the difference between target surprise and actual surprise - /// - public float Mu - { - get => _mu; - set => _mu = value; - } - - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. - /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// - public float Tau { get; set; } - - /// - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// - public float Eta { get; set; } - - /// - /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn - /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects - /// the performance of the algorithm. - /// - public int M { get; set; } - - /// - /// Create a new Mirostat 2.0 sampler - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. 
- /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn - /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects - /// the performance of the algorithm. - public MirostatSelection(float tau, float eta, int m = 100) - { - Tau = tau; - Eta = eta; - M = m; - } - - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu); - } - - /// - public void Reset() - { - _mu = 2 * Tau; - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/StandardSelection.cs b/LLama/Sampling/Selection/StandardSelection.cs deleted file mode 100644 index 3e3bd086..00000000 --- a/LLama/Sampling/Selection/StandardSelection.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select from all possible tokens according to their probability -/// -public sealed class StandardSelection - : ITokenSelector -{ - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleToken(ctx); - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/GrammarSampling.cs b/LLama/Sampling/Tokens/GrammarSampling.cs deleted file mode 100644 index b823a7f9..00000000 --- a/LLama/Sampling/Tokens/GrammarSampling.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System; -using LLama.Grammars; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Apply a grammar to prevent sampling tokens which do not match the grammar -/// -public sealed class GrammarSampling - : ITokenDataProcessor -{ - private SafeLLamaGrammarHandle? _handle; - - /// - /// Grammar to use for sampling - /// - public Grammar? 
Grammar { get; set; } - - /// - /// Create a new - /// - /// - public GrammarSampling(Grammar grammar) - { - Grammar = grammar; - } - - /// - public void Reset() - { - _handle?.Dispose(); - _handle = null; - } - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - // Create a new grammar instance if necessary - _handle ??= Grammar?.CreateInstance(); - - // Apply it - if (_handle != null) - tokens.ApplyGrammar(ctx, _handle); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - _handle?.AcceptToken(ctx, token); - } - - /// - public void Dispose() - { - _handle?.Dispose(); - _handle = null; - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/ITokenDataProcessor.cs b/LLama/Sampling/Tokens/ITokenDataProcessor.cs deleted file mode 100644 index e6679cc2..00000000 --- a/LLama/Sampling/Tokens/ITokenDataProcessor.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -using llama_token = Int32; - -/// -/// Processes token logits before sampling, applying penalties to certain tokens -/// -public interface ITokenDataProcessor - : IDisposable -{ - /// - /// Process token logits in a LLamaTokenDataArray - /// - /// The context this is operating in - /// The token data array to process - /// The most recent tokens output - /// LLamaTokenDataArray, created from logits - void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens); - - /// - /// Inform this process when a token is accepted by the model - /// - /// - /// - void AcceptToken(SafeLLamaContextHandle ctx, int token); - - /// - /// Reset all internal sampling state - /// - void Reset(); -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/LocallyTypicalSampling.cs b/LLama/Sampling/Tokens/LocallyTypicalSampling.cs deleted file mode 100644 index 3f602c9a..00000000 --- a/LLama/Sampling/Tokens/LocallyTypicalSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 
-/// -public sealed class LocallyTypicalSampling - : ITokenDataProcessor -{ - /// - /// P value for locally typical sampling - /// - public float P { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.LocallyTypical(ctx, P, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/MinPSampling.cs b/LLama/Sampling/Tokens/MinPSampling.cs deleted file mode 100644 index c3adf026..00000000 --- a/LLama/Sampling/Tokens/MinPSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 -/// -public sealed class MinPSampling - : ITokenDataProcessor -{ - /// - /// All tokens with probability greater than this will be kept - /// - public float P { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.MinP(ctx, P, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/RepetitionPenalty.cs b/LLama/Sampling/Tokens/RepetitionPenalty.cs deleted file mode 100644 index 3cfdbcd4..00000000 --- a/LLama/Sampling/Tokens/RepetitionPenalty.cs +++ /dev/null @@ -1,77 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. -/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. -/// -public sealed class RepetitionPenalty - : ITokenDataProcessor -{ - private float _alphaFreq; - private float _alphaPresence; - - /// - /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 - /// - public float RepeatPenalty { get; set; } = 1.1f; - - /// - /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
- /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text - /// so far, decreasing the model's likelihood to repeat the same line verbatim. - ///
- public float AlphaFrequency - { - get => _alphaFreq; - set - { - if (value < -2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); - if (value > 2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); - _alphaFreq = value; - } - } - - /// - /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
- /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the - /// text so far, increasing the model's likelihood to talk about new topics. - ///
- public float AlphaPresence - { - get => _alphaPresence; - set - { - if (value < -2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); - if (value > 2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); - _alphaPresence = value; - } - } - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TailFreeSampling.cs b/LLama/Sampling/Tokens/TailFreeSampling.cs deleted file mode 100644 index 8e9fb2b5..00000000 --- a/LLama/Sampling/Tokens/TailFreeSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. -/// -public sealed class TailFreeSampling - : ITokenDataProcessor -{ - /// - /// Z value for tail free sampling - /// - public float Z { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.TailFree(ctx, Z, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TemperatureSampling.cs b/LLama/Sampling/Tokens/TemperatureSampling.cs deleted file mode 100644 index 0186f275..00000000 --- a/LLama/Sampling/Tokens/TemperatureSampling.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Sample with temperature. -/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual -/// -public sealed class TemperatureSampling - : ITokenDataProcessor -{ - /// - /// Temperature value to apply - /// - public float Temperature { get; set; } = 0.5f; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.Temperature(ctx, Temperature); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopKSampling.cs b/LLama/Sampling/Tokens/TopKSampling.cs deleted file mode 100644 index 3f797c85..00000000 --- a/LLama/Sampling/Tokens/TopKSampling.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Sample with TopK, removing all by the K most likely tokens. 
-/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -/// -public sealed class TopKSampling - : ITokenDataProcessor -{ - /// - /// Number of tokens to keep - /// - public int Count { get; set; } - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.TopK(ctx, Count); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopPSampling.cs b/LLama/Sampling/Tokens/TopPSampling.cs deleted file mode 100644 index 577ce3bc..00000000 --- a/LLama/Sampling/Tokens/TopPSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -/// -public sealed class TopPSampling - : ITokenDataProcessor -{ - /// - /// P valies for TopP - /// - public float P { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.TopP(ctx, P, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file From f669a4f5a70bbad5c5341199168244f01ee5cdeb Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Sun, 10 Dec 2023 09:34:11 -0600 Subject: [PATCH 12/22] Update the Chinese chat sample to use new ChatSession integration --- .../Assets/chat-with-kunkun-chinese.json | 24 +++ LLama.Examples/Examples/ChatChineseGB2312.cs | 153 ++++++++++++------ LLama.Examples/Examples/Runner.cs | 2 +- 3 files changed, 129 insertions(+), 50 deletions(-) create mode 100644 LLama.Examples/Assets/chat-with-kunkun-chinese.json diff --git a/LLama.Examples/Assets/chat-with-kunkun-chinese.json b/LLama.Examples/Assets/chat-with-kunkun-chinese.json new file mode 100644 index 00000000..30112327 --- /dev/null +++ b/LLama.Examples/Assets/chat-with-kunkun-chinese.json @@ -0,0 +1,24 @@ +{ + "messages": [ + { + "author_role": "System", + "content": "������һ������û��ĶԻ��������������һ���ڸ����涼ӵ�зḻ�������������dz����ڻش��û�������Ͱ����û���?" + }, + { + "author_role": "User", + "content": "��ã�������?" 
+ }, + { + "author_role": "Assistant", + "content": "��ã���ʲô���ܰ��������" + }, + { + "author_role": "User", + "content": "�й����׶����������У�" + }, + { + "author_role": "Assistant", + "content": "��������˭��" + } + ] +} diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index ff27b962..d250a454 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -1,69 +1,124 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; +using System.Text; using LLama.Common; -namespace LLama.Examples.Examples +namespace LLama.Examples.Examples; + +public class ChatChineseGB2312 { - public class ChatChineseGB2312 + private static string ConvertEncoding(string input, Encoding original, Encoding target) + { + byte[] bytes = original.GetBytes(input); + var convertedBytes = Encoding.Convert(original, target, bytes); + return target.GetString(convertedBytes); + } + + public static async Task Run() { - private static string ConvertFromEncodingToAnother(string input, Encoding original, Encoding target) + // Register provider for GB2312 encoding + Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common in windows. It's recommended" + + " to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers."); + Console.ForegroundColor = ConsoleColor.White; + + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5, + Encoding = Encoding.UTF8 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + ChatSession session; + if (Directory.Exists("Assets/chat-with-kunkun-chinese")) + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Loading session from disk."); + Console.ForegroundColor = ConsoleColor.White; + + session = new ChatSession(executor); + session.LoadSession("Assets/chat-with-kunkun-chinese"); + } + else { - byte[] bytes = original.GetBytes(input); - var convertedBytes = Encoding.Convert(original, target, bytes); - return target.GetString(convertedBytes); + var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? 
new ChatHistory(); + + session = new ChatSession(executor, chatHistory); } - public static async Task Run() + session + .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户")) + .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + // User and Assistant in Chinese (User is: 用户, Assistant is: 坤坤) + new string[] { "用户:", "坤坤:" }, + redundancyLength: 8)); + + InferenceParams inferenceParams = new InferenceParams() { - Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); // Register gb2312 encoding - Console.Write("Please input your model path: "); - var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-kunkun-chinese.txt", encoding: Encoding.GetEncoding("gb2312")).Trim(); - prompt = ConvertFromEncodingToAnother(prompt, Encoding.GetEncoding("gb2312"), Encoding.UTF8); + Temperature = 0.9f, + AntiPrompts = new List { "用户:" } + }; - var parameters = new ModelParams(modelPath) - { - ContextSize = 1024, - Seed = 1337, - GpuLayerCount = 20, - Encoding = Encoding.UTF8 - }; - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - var executor = new InteractiveExecutor(context); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); - var session = new ChatSession(executor).WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户")); + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? ""; - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common in windows. It's recommended" + - " to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers."); - Console.ForegroundColor = ConsoleColor.White; + while (userInput != "exit") + { + // Convert the encoding from gb2312 to utf8 for the language model + // and later saving to the history json file. + userInput = ConvertEncoding(userInput, Encoding.GetEncoding("gb2312"), Encoding.UTF8); - // show the prompt - Console.Write(prompt); - while (true) + if (userInput == "save") { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() + session.SaveSession("Assets/chat-with-kunkun-chinese"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session saved."); + } + else if (userInput == "regenerate") + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Regenerating last response ..."); + + await foreach ( + var text + in session.RegenerateAssistantMessageAsync( + inferenceParams)) { - Temperature = 0.3f, - TopK = 5, - TopP = 0.85f, - AntiPrompts = new List { "用户:" }, - MaxTokens = 2048, - RepeatPenalty = 1.05f - })) + Console.ForegroundColor = ConsoleColor.White; + + // Convert the encoding from utf8 to gb2312 for the console output. 
+ Console.Write(ConvertEncoding(text, Encoding.UTF8, Encoding.GetEncoding("gb2312"))); + } + } + else + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) { - //Console.Write(text); - Console.Write(ConvertFromEncodingToAnother(text, Encoding.UTF8, Encoding.GetEncoding("gb2312"))); + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); } - - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; } } } diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs index 0a37dcba..3d9858e1 100644 --- a/LLama.Examples/Examples/Runner.cs +++ b/LLama.Examples/Examples/Runner.cs @@ -9,6 +9,7 @@ public class Runner { "Run a chat session with history.", ChatSessionWithHistory.Run }, { "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run }, { "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run }, + { "Run a chat session in Chinese GB2312 encoding", ChatChineseGB2312.Run }, { "Interactive mode chat by using executor.", InteractiveModeExecute.Run }, { "Instruct mode chat by using executor.", InstructModeExecute.Run }, { "Stateless mode chat by using executor.", StatelessModeExecute.Run }, @@ -24,7 +25,6 @@ public class Runner { "Coding Assistant.", CodingAssistant.Run }, { "Batch Decoding.", BatchedDecoding.Run }, { "SK Kernel Memory.", KernelMemory.Run }, - { "Chinese gb2312 chat", ChatChineseGB2312.Run }, { "Exit", async () => Environment.Exit(0) } }; From 29c5c6e93c806aa5f719c25c9c1987690419999f Mon Sep 17 00:00:00 2001 From: Philipp Bauer Date: Sun, 10 Dec 2023 09:34:32 -0600 Subject: [PATCH 13/22] Update the StatefulChatService to use new ChatSession integration --- LLama.WebAPI/Services/StatefulChatService.cs | 35 +++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index f1eb3538..f45c98ee 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -11,8 +11,7 @@ public class StatefulChatService : IDisposable private readonly LLamaContext _context; private bool _continue = false; - private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" - + "User: "; + private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. 
Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision."; public StatefulChatService(IConfiguration configuration) { @@ -25,7 +24,9 @@ public class StatefulChatService : IDisposable using var weights = LLamaWeights.LoadFromFile(@params); _context = new LLamaContext(weights, @params); + _session = new ChatSession(new InteractiveExecutor(_context)); + _session.History.AddMessage(Common.AuthorRole.System, SystemPrompt); } public void Dispose() @@ -35,10 +36,8 @@ public class StatefulChatService : IDisposable public async Task Send(SendMessageInput input) { - var userInput = input.Text; if (!_continue) { - userInput = SystemPrompt + userInput; Console.Write(SystemPrompt); _continue = true; } @@ -47,11 +46,14 @@ public class StatefulChatService : IDisposable Console.Write(input.Text); Console.ForegroundColor = ConsoleColor.White; - var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() - { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, - }); + var outputs = _session.ChatAsync( + new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text), + new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + var result = ""; await foreach (var output in outputs) { @@ -64,10 +66,8 @@ public class StatefulChatService : IDisposable public async IAsyncEnumerable SendStream(SendMessageInput input) { - var userInput = input.Text; if (!_continue) { - userInput = SystemPrompt + userInput; Console.Write(SystemPrompt); _continue = true; } @@ -76,11 +76,14 @@ public class StatefulChatService : IDisposable Console.Write(input.Text); Console.ForegroundColor = ConsoleColor.White; - var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() - { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, - }); + var outputs = _session.ChatAsync( + new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text) + , new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + await foreach (var output in outputs) { Console.Write(output); From 8fb447681371e1a77bd620757e2b608a41860bbf Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Dec 2023 06:20:24 +0000 Subject: [PATCH 14/22] build(deps): bump xunit.runner.visualstudio from 2.5.4 to 2.5.5 Bumps [xunit.runner.visualstudio](https://github.com/xunit/visualstudio.xunit) from 2.5.4 to 2.5.5. - [Release notes](https://github.com/xunit/visualstudio.xunit/releases) - [Commits](https://github.com/xunit/visualstudio.xunit/compare/2.5.4...2.5.5) --- updated-dependencies: - dependency-name: xunit.runner.visualstudio dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- LLama.Unittest/LLama.Unittest.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index fbaee5ed..b3092488 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -16,7 +16,7 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive all From 836f071cd0aebf1d19cd1eb620e633f8f7e66567 Mon Sep 17 00:00:00 2001 From: Rinne Date: Mon, 11 Dec 2023 22:21:54 +0800 Subject: [PATCH 15/22] fix: Chinese example. 
--- LLama.Examples/Assets/chat-with-kunkun-chinese.json | 10 +++++----- LLama.Examples/Examples/ChatChineseGB2312.cs | 6 +----- LLama.Examples/LLama.Examples.csproj | 3 +++ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/LLama.Examples/Assets/chat-with-kunkun-chinese.json b/LLama.Examples/Assets/chat-with-kunkun-chinese.json index 30112327..cae03029 100644 --- a/LLama.Examples/Assets/chat-with-kunkun-chinese.json +++ b/LLama.Examples/Assets/chat-with-kunkun-chinese.json @@ -2,23 +2,23 @@ "messages": [ { "author_role": "System", - "content": "������һ������û��ĶԻ��������������һ���ڸ����涼ӵ�зḻ�������������dz����ڻش��û�������Ͱ����û���?" + "content": "下面是一段你和用户的对话,你叫坤坤,是一个在各方面都拥有丰富经验的助理,你非常乐于回答用户的问题和帮助用户。" }, { "author_role": "User", - "content": "��ã�������?" + "content": "你好,坤坤。" }, { "author_role": "Assistant", - "content": "��ã���ʲô���ܰ��������" + "content": "你好,有什么我能帮助你的吗?" }, { "author_role": "User", - "content": "�й����׶����������У�" + "content": "中国的首都是哪座城市?" }, { "author_role": "Assistant", - "content": "��������˭��" + "content": "中国的首都是北京市。" } ] } diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index d250a454..bb3d3f80 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -55,11 +55,7 @@ public class ChatChineseGB2312 } session - .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户")) - .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( - // User and Assistant in Chinese (User is: 用户, Assistant is: 坤坤) - new string[] { "用户:", "坤坤:" }, - redundancyLength: 8)); + .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤")); InferenceParams inferenceParams = new InferenceParams() { diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index 958bb6c4..d3acbb82 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -44,6 +44,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest From 85dc43dde0fede6b7d8cc9905f0e9f5e9b36b221 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Dec 2023 14:26:14 +0000 Subject: [PATCH 16/22] build(deps): bump xunit from 2.6.2 to 2.6.3 Bumps [xunit](https://github.com/xunit/xunit) from 2.6.2 to 2.6.3. - [Commits](https://github.com/xunit/xunit/compare/2.6.2...2.6.3) --- updated-dependencies: - dependency-name: xunit dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- LLama.Unittest/LLama.Unittest.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index b3092488..4692f101 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -15,7 +15,7 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive all From fb75e06293d904e13a930759dc9b1661fe61a15e Mon Sep 17 00:00:00 2001 From: Rinne Date: Mon, 11 Dec 2023 22:27:47 +0800 Subject: [PATCH 17/22] fix: output prefix of Chinese example. 
--- LLama.Examples/Examples/ChatChineseGB2312.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index bb3d3f80..3a9fe6c7 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -67,6 +67,8 @@ public class ChatChineseGB2312 Console.WriteLine("The chat session has started."); // show the prompt + Console.ForegroundColor = ConsoleColor.White; + Console.Write("用户:"); Console.ForegroundColor = ConsoleColor.Green; string userInput = Console.ReadLine() ?? ""; From df66d7e0c6715d69e3669ec935452d1e2c761765 Mon Sep 17 00:00:00 2001 From: xbotter Date: Tue, 12 Dec 2023 19:08:34 +0800 Subject: [PATCH 18/22] Upgrade unittest target framework to .net8 --- LLama.Unittest/LLama.Unittest.csproj | 4 ++-- LLama.Unittest/ModelsParamsTests.cs | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 4692f101..a206a23d 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -1,7 +1,7 @@ - + - net6.0 + net8.0 LLama.Unittest enable AnyCPU;x64 diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs index aec4b5a3..0eb763c3 100644 --- a/LLama.Unittest/ModelsParamsTests.cs +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -1,4 +1,5 @@ using LLama.Common; +using System.Text.Json; namespace LLama.Unittest { @@ -7,6 +8,8 @@ namespace LLama.Unittest [Fact] public void SerializeRoundTripSystemTextJson() { + var options = new JsonSerializerOptions(); + options.Converters.Add(new EncodingConverter()); var expected = new ModelParams("abc/123") { BatchSize = 17, @@ -16,8 +19,8 @@ namespace LLama.Unittest TensorSplits = { [0] = 3 } }; - var json = System.Text.Json.JsonSerializer.Serialize(expected); - var actual = System.Text.Json.JsonSerializer.Deserialize(json)!; + var json = JsonSerializer.Serialize(expected, options); + var actual = JsonSerializer.Deserialize(json, options)!; // Cannot compare splits with default equality, check they are sequence equal and then set to null Assert.Equal((IEnumerable)expected.TensorSplits, expected.TensorSplits); From e6148c952e7630116725f86c97b9120601e1f417 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 12 Dec 2023 15:50:32 +0000 Subject: [PATCH 19/22] Fixed encoding of `Encoding` --- LLama.Unittest/ModelsParamsTests.cs | 11 +++++++---- LLama/Common/ModelParams.cs | 30 ++++++++++++----------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs index 0eb763c3..34c9c21b 100644 --- a/LLama.Unittest/ModelsParamsTests.cs +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -8,8 +8,6 @@ namespace LLama.Unittest [Fact] public void SerializeRoundTripSystemTextJson() { - var options = new JsonSerializerOptions(); - options.Converters.Add(new EncodingConverter()); var expected = new ModelParams("abc/123") { BatchSize = 17, @@ -19,14 +17,19 @@ namespace LLama.Unittest TensorSplits = { [0] = 3 } }; - var json = JsonSerializer.Serialize(expected, options); - var actual = JsonSerializer.Deserialize(json, options)!; + var json = JsonSerializer.Serialize(expected); + var actual = JsonSerializer.Deserialize(json)!; // Cannot compare splits with default equality, check they are sequence equal and then set to null Assert.Equal((IEnumerable)expected.TensorSplits, expected.TensorSplits); actual.TensorSplits = 
null!; expected.TensorSplits = null!; + // Check encoding is the same + var b1 = expected.Encoding.GetBytes("Hello"); + var b2 = actual.Encoding.GetBytes("Hello"); + Assert.True(b1.SequenceEqual(b2)); + Assert.Equal(expected, actual); } diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index f1cef072..f228d7a3 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -92,9 +92,19 @@ namespace LLama.Common /// public bool VocabOnly { get; set; } + /// + /// `Encoding` cannot be directly JSON serialized, instead store the name as a string which can + /// + [JsonPropertyName("Encoding")] + private string EncodingName { get; set; } = Encoding.UTF8.WebName; + /// - [JsonConverter(typeof(EncodingConverter))] - public Encoding Encoding { get; set; } = Encoding.UTF8; + [JsonIgnore] + public Encoding Encoding + { + get => Encoding.GetEncoding(EncodingName); + set => EncodingName = value.WebName; + } /// /// @@ -113,22 +123,6 @@ namespace LLama.Common } } - internal class EncodingConverter - : JsonConverter - { - public override Encoding? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - var name = reader.GetString(); - if (name == null) - return null; - return Encoding.GetEncoding(name); - } - - public override void Write(Utf8JsonWriter writer, Encoding value, JsonSerializerOptions options) - { - writer.WriteStringValue(value.WebName); - } - } internal class TensorSplitsCollectionConverter : JsonConverter From 01c7f1b4dab4701f964ac2d4b2c785e0a5474c3c Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 12 Dec 2023 16:56:51 +0000 Subject: [PATCH 20/22] Update LLama/Common/ModelParams.cs --- LLama/Common/ModelParams.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index f228d7a3..25e638ad 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -96,6 +96,7 @@ namespace LLama.Common /// `Encoding` cannot be directly JSON serialized, instead store the name as a string which can /// [JsonPropertyName("Encoding")] + [JsonInclude] private string EncodingName { get; set; } = Encoding.UTF8.WebName; /// From 0b8422ea7f1893bb5929e64a2815d58243fe2c8b Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 13 Dec 2023 14:46:43 +0000 Subject: [PATCH 21/22] Added AVX and AVX2 to MacOS x86_64 builds --- .github/workflows/compile.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml index ec5f725e..a8169be1 100644 --- a/.github/workflows/compile.yml +++ b/.github/workflows/compile.yml @@ -140,7 +140,7 @@ jobs: - build: 'arm64' defines: '-DCMAKE_OSX_ARCHITECTURES=arm64' - build: 'x64' - defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF' + defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF -DLLAMA_AVX=ON -DLLAMA_AVX2=ON' runs-on: macos-latest steps: - uses: actions/checkout@v3 From 340bbbcf486aeb516ee51f9acfa50dd4be12cfa3 Mon Sep 17 00:00:00 2001 From: xbotter Date: Thu, 14 Dec 2023 09:10:31 +0800 Subject: [PATCH 22/22] Move JSON converter for TensorSplitsCollection --- LLama/Abstractions/IModelParams.cs | 24 ++++++++++++++++++++++++ LLama/Common/ModelParams.cs | 17 ----------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index 2ecfe49c..4a3dde7a 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -3,6 +3,9 @@ using 
System.Buffers; using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; +using LLama.Common; using LLama.Native; namespace LLama.Abstractions @@ -105,6 +108,7 @@ namespace LLama.Abstractions /// /// A fixed size array to set the tensor splits across multiple GPUs /// + [JsonConverter(typeof(TensorSplitsCollectionConverter))] public sealed class TensorSplitsCollection : IEnumerable { @@ -174,4 +178,24 @@ namespace LLama.Abstractions } #endregion } + + /// + /// A JSON converter for + /// + public class TensorSplitsCollectionConverter + : JsonConverter + { + /// + public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var arr = JsonSerializer.Deserialize(ref reader, options) ?? Array.Empty(); + return new TensorSplitsCollection(arr); + } + + /// + public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options) + { + JsonSerializer.Serialize(writer, value.Splits, options); + } + } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 25e638ad..cecd655a 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -59,7 +59,6 @@ namespace LLama.Common public bool EmbeddingMode { get; set; } /// - [JsonConverter(typeof(TensorSplitsCollectionConverter))] public TensorSplitsCollection TensorSplits { get; set; } = new(); /// @@ -123,20 +122,4 @@ namespace LLama.Common ModelPath = ""; } } - - - internal class TensorSplitsCollectionConverter - : JsonConverter - { - public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - var arr = JsonSerializer.Deserialize(ref reader, options) ?? Array.Empty(); - return new TensorSplitsCollection(arr); - } - - public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options) - { - JsonSerializer.Serialize(writer, value.Splits, options); - } - } }
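
As a quick illustration of the serialization behaviour these final patches arrive at, the sketch below round-trips a ModelParams instance through System.Text.Json and a file on disk. This is a hedged, illustrative sketch rather than code from the series: the property values and the "model-params.json" file name are arbitrary placeholders, and it assumes the post-patch ModelParams from LLama.Common (string-backed Encoding, TensorSplitsCollection with its converter declared on the class), matching what the SerializeRoundTripSystemTextJson test exercises.

using System;
using System.IO;
using System.Text;
using System.Text.Json;
using LLama.Common;

// Illustrative values only; "abc/123" and "model-params.json" are placeholders.
var expected = new ModelParams("abc/123")
{
    BatchSize = 17,
    ContextSize = 1024,
    Seed = 1337,
    GpuLayerCount = 5,
    Encoding = Encoding.Unicode,
    TensorSplits = { [0] = 3 }
};

// No custom JsonSerializerOptions are required here: Encoding is persisted
// as its WebName, and TensorSplitsCollection picks up the converter
// declared via [JsonConverter] on the class itself.
var json = JsonSerializer.Serialize(expected);
File.WriteAllText("model-params.json", json);

// Reload and confirm the encoding survives the round trip.
var actual = JsonSerializer.Deserialize<ModelParams>(File.ReadAllText("model-params.json"))!;
Console.WriteLine(actual.Encoding.WebName); // prints "utf-16"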