| Author | SHA1 | Message | Date |
|---|---|---|---|
|  | 4f44e3b198 | refactor: init some refactorings for experiment. | 1 year ago |
```diff
@@ -1,7 +1,9 @@
 using System.Collections.Generic;
 using LLama.Common;
+using LLama.Control;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Transform;
 
 namespace LLama.Abstractions
 {
@@ -114,5 +116,15 @@ namespace LLama.Abstractions
         /// Set a custom sampling pipeline to use. <b>If this is set All other sampling parameters are ignored!</b>
         /// </summary>
         ISamplingPipeline? SamplingPipeline { get; set; }
+
+        /// <summary>
+        /// Set a custom generation control to use. <b>If this is set antiprompt will be ignored!</b>
+        /// </summary>
+        IGenerationControl GenerationControl { get; set; }
+
+        /// <summary>
+        /// Set a custom tokenizer to use.
+        /// </summary>
+        ITokenizer Tokenizer { get; set; }
     }
 }
```
```diff
@@ -18,8 +18,8 @@
         /// </summary>
         /// <param name="text">Your prompt</param>
         /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
+        /// <param name="cancellationToken">A cancellation token.</param>
         /// <returns></returns>
-        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
+        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default);
     }
 }
```
```diff
@@ -3,6 +3,9 @@ using System;
 using System.Collections.Generic;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Control;
+using LLama.Transform;
+using System.Text;
 
 namespace LLama.Common
 {
@@ -80,6 +83,12 @@ namespace LLama.Common
         /// <inheritdoc />
         public ISamplingPipeline? SamplingPipeline { get; set; }
+
+        /// <inheritdoc />
+        public IGenerationControl GenerationControl { get; set; } = new DefaultGenerationControl();
+
+        /// <inheritdoc />
+        public ITokenizer Tokenizer { get; set; } = new DefaultTokenizer(Encoding.UTF8);
     }
 
     /// <summary>
```
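The two new `InferenceParams` members default to `DefaultGenerationControl` and a UTF-8 `DefaultTokenizer`, so existing callers keep their behaviour. A minimal sketch of setting them explicitly (the antiprompt string and token limit below are illustrative values, not library defaults):

```csharp
using System.Text;
using LLama.Common;
using LLama.Control;
using LLama.Transform;

// GenerationControl and Tokenizer are shown explicitly here even though
// these are already the defaults introduced by this commit.
var inferenceParams = new InferenceParams
{
    MaxTokens = 128,
    AntiPrompts = new[] { "User:" },                    // still consumed by DefaultGenerationControl
    GenerationControl = new DefaultGenerationControl(),
    Tokenizer = new DefaultTokenizer(Encoding.UTF8),
};
```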
```diff
@@ -1,7 +1,7 @@
 using System;
 using System.Collections.Generic;
 
-namespace LLama
+namespace LLama.Control
 {
     /// <summary>
     /// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
```
```diff
@@ -0,0 +1,42 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
+    /// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
+    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+    /// </summary>
+    public sealed class DefaultGenerationControl: IGenerationControl
+    {
+        private AntipromptProcessor _antipromptProcessor;
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public DefaultGenerationControl()
+        {
+            _antipromptProcessor = new AntipromptProcessor();
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
+        {
+            _antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
+            return _antipromptProcessor.Add(lastOutputText);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
+        {
+            return false;
+        }
+    }
+}
```
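Because `DefaultGenerationControl` owns an `AntipromptProcessor`, it accumulates streamed text across calls, which is what the warning about state in its doc comment refers to. A hedged usage sketch (`context` is assumed to be an existing `LLamaContext`; the antiprompt is an illustrative value):

```csharp
using LLama.Common;
using LLama.Control;

var control = new DefaultGenerationControl();
var inferenceParams = new InferenceParams { AntiPrompts = new[] { "\nUser:" } };

// Feed streamed chunks as they arrive; the processor keeps the recent text internally,
// so it should also catch an antiprompt that is split across chunks.
bool stop1 = control.ShouldStopGeneration(context, inferenceParams, "\nUse"); // expected: false
bool stop2 = control.ShouldStopGeneration(context, inferenceParams, "r:");    // expected: true once "\nUser:" has appeared
```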
```diff
@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// Control the text generation of LLama Executors.
+    /// </summary>
+    public interface IGenerationControl
+    {
+        /// <summary>
+        /// Use the last output text to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputText"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);
+
+        /// <summary>
+        /// Use the last output ids to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputIds"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
+    }
+}
```
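Any implementation of this interface can now replace the antiprompt check. A minimal hypothetical sketch (the `MaxChunksControl` type is invented here for illustration, not part of the commit) that stops after a fixed number of streamed chunks:

```csharp
using System.Collections.Generic;
using LLama;
using LLama.Abstractions;
using LLama.Control;

public sealed class MaxChunksControl : IGenerationControl
{
    private readonly int _maxChunks;
    private int _seen;

    public MaxChunksControl(int maxChunks) => _maxChunks = maxChunks;

    // Stop once the executor has yielded _maxChunks text chunks.
    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
        => ++_seen >= _maxChunks;

    // The text-based overload above is sufficient for this control.
    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
        => false;
}
```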
```diff
@@ -2,6 +2,7 @@
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
+using LLama.Transform;
 using Microsoft.Extensions.Logging;
 using System;
 using System.Collections.Generic;
```
```diff
@@ -8,6 +8,7 @@ using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Control;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -49,7 +50,7 @@
         /// <inheritdoc />
         public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            // Ensure the context from last time is disposed (it always hould be)
+            // Ensure the context from last time is disposed (it always should be)
             if (!Context.NativeHandle.IsClosed)
                 Context.Dispose();
 
@@ -57,48 +58,53 @@
             using var context = _weights.CreateContext(_params, _logger);
             Context = context;
+
+            await foreach(var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
+            {
+                yield return item;
+            }
+        }
+
+        public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
 
             // Sanity check inference params
             inferenceParams ??= new InferenceParams();
-            if (inferenceParams.TokensKeep > Context.ContextSize)
-                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
-
-            // Create decoders for the token stream
-            var decoder = new StreamingTokenDecoder(Context);
-            var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
+            if (inferenceParams.TokensKeep > context.ContextSize)
+                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");
 
             // Keep track of the last N tokens emitted
-            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
+            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
             var lastTokens = new List<llama_token>(repeat_last_n);
             for (var i = 0; i < repeat_last_n; i++)
                 lastTokens.Add(0);
 
             // Tokenize the prompt
-            var tokens = Context.Tokenize(prompt).ToList();
+            var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
             lastTokens.AddRange(tokens);
             var n_past = 1 + tokens.Count;
 
             // Evaluate the prompt
-            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
+            await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
                 .ConfigureAwait(false);
 
             // Begin loop, evaluating one token at a time
             var mu = (float?)null;
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
-            for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
+            for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
                 llama_token id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                    id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
                 }
                 else
                 {
                     // Penalize the generated tokens by various penalties
-                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     // Sample a single token
-                    id = Context.Sample(
+                    id = context.Sample(
                         tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                         inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
                         inferenceParams.MinP
@@ -106,12 +112,11 @@
                 }
 
                 // Decode this token into text
-                decoder.Add(id);
-                var decoded = decoder.Read();
+                var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
                 yield return decoded;
 
-                // Check if any of the antiprompts have been generated
-                if (antiprocessor.Add(decoded))
+                // Check if the generation should stop
+                if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
                     break;
 
                 lastTokens.Add(id);
@@ -120,19 +125,19 @@
                 // when run out of context
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
-                if (n_past + tokens.Count >= Context.ContextSize)
+                if (n_past + tokens.Count >= context.ContextSize)
                 {
                     var n_left = n_past - inferenceParams.TokensKeep - 1;
                     var n_discard = n_left / 2;
 
-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+                    NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
                 }
 
                 // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
+                n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
                     .ConfigureAwait(false);
             }
         }
```
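The refactor splits `InferAsync` into a thin instance wrapper plus a static overload that runs against a caller-supplied `LLamaContext`, so the per-call context creation can be bypassed. A hedged calling sketch, assuming the enclosing class is `StatelessExecutor` and reusing the `LLamaWeights`/`ModelParams` APIs already visible above (the model path is a placeholder):

```csharp
using System;
using LLama;
using LLama.Common;

// Load weights once and create a context the caller owns.
var modelParams = new ModelParams("path/to/model.gguf");   // placeholder path
using var weights = LLamaWeights.LoadFromFile(modelParams);
using var context = weights.CreateContext(modelParams);

// Stream completions through the new static overload against that context.
await foreach (var chunk in StatelessExecutor.InferAsync("Q: 2+2? A:", context, new InferenceParams { MaxTokens = 32 }))
    Console.Write(chunk);
```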
```diff
@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using LLama.Common;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+
+namespace LLama
+{
+    /// <summary>
+    /// A class to execute text completion task.
+    /// </summary>
+    public class TextCompletion
+    {
+        public string Execute(string prompt, IInferenceParams? inferenceParams = null)
+        {
+            throw new NotImplementedException();
+        }
+
+        public ChatHistory Execute(ChatHistory prompt, IInferenceParams? inferenceParams = null)
+        {
+            throw new NotImplementedException();
+        }
+
+        public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            throw new NotImplementedException();
+        }
+    }
+}
```
```diff
@@ -0,0 +1,53 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+    /// <summary>
+    /// The default tokenizer of LLamaSharp. This class should not be inherited.
+    /// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
+    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+    /// </summary>
+    public sealed class DefaultTokenizer: ITokenizer
+    {
+        private Encoding _encoding;
+        private StreamingTokenDecoder _tokenDecoder;
+
+        /// <summary>
+        /// Initialize a new tokenizer with the specified encoding.
+        /// </summary>
+        /// <param name="encoding"></param>
+        public DefaultTokenizer(Encoding encoding)
+        {
+            _encoding = encoding;
+            _tokenDecoder = new StreamingTokenDecoder(encoding);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
+        {
+            return context.Tokenize(text, addBos, special);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public string Detokenize(LLamaContext context, int token)
+        {
+            _tokenDecoder.Add(token, context.NativeHandle.ModelHandle);
+            return _tokenDecoder.Read();
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
+        {
+            _tokenDecoder.AddRange(tokens, context.NativeHandle.ModelHandle);
+            return _tokenDecoder.Read();
+        }
+    }
+}
```
```diff
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+    public interface ITokenizer
+    {
+        IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false);
+
+        string Detokenize(LLamaContext context, int token);
+
+        string Detokenize(LLamaContext context, IEnumerable<int> tokens);
+    }
+}
```
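Custom tokenizers plug into the new `IInferenceParams.Tokenizer` property. A hypothetical sketch (the `TrimmingTokenizer` name is invented for illustration) that trims the prompt before delegating to the context and reuses `DefaultTokenizer` for detokenization:

```csharp
using System.Collections.Generic;
using System.Text;
using LLama;
using LLama.Transform;

public sealed class TrimmingTokenizer : ITokenizer
{
    // Compose the built-in tokenizer for the decode direction.
    private readonly DefaultTokenizer _inner = new DefaultTokenizer(Encoding.UTF8);

    public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
        => context.Tokenize(text.Trim(), addBos, special);

    public string Detokenize(LLamaContext context, int token)
        => _inner.Detokenize(context, token);

    public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
        => _inner.Detokenize(context, tokens);
}
```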
```diff
@@ -6,14 +6,14 @@ using System.Text;
 using LLama.Extensions;
 using LLama.Native;
 
-namespace LLama
+namespace LLama.Transform
 {
     /// <summary>
     /// Decodes a stream of tokens into a stream of characters
     /// </summary>
     public sealed class StreamingTokenDecoder
     {
-        private readonly SafeLlamaModelHandle _weights;
+        private readonly SafeLlamaModelHandle? _weights;
         private readonly Decoder _decoder;
 
         private readonly List<char> _characters = new();
@@ -29,8 +29,8 @@
         /// </summary>
         /// <param name="encoding">Text encoding to use</param>
         /// <param name="weights">Model weights</param>
-        public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
-            : this(encoding, weights.NativeHandle)
+        public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
+            : this(encoding, weights?.NativeHandle)
         {
         }
@@ -69,14 +69,19 @@
         /// Add a single token to the decoder
         /// </summary>
         /// <param name="token"></param>
-        public void Add(int token)
+        public void Add(int token, SafeLlamaModelHandle? weights = null)
         {
+            weights ??= _weights;
+            if(weights is null)
+            {
+                throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
+            }
             var charsArr = ArrayPool<char>.Shared.Rent(16);
             var bytesArr = ArrayPool<byte>.Shared.Rent(16);
             try
             {
                 // Convert this token into bytes
-                var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
+                var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;
 
                 // Convert those bytes into characters
                 var bytesOffset = 0;
@@ -133,10 +138,10 @@
         /// Add all tokens in the given enumerable
         /// </summary>
         /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<int> tokens)
+        public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
         {
             foreach (var item in tokens)
-                Add(item);
+                Add(item, weights);
         }
 
         /// <summary>
```
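With `_weights` now nullable, the decoder can be constructed without a model and given a handle per call, which is how `DefaultTokenizer` uses it above. A short hedged sketch (`model` is assumed to be a loaded `LLamaWeights` and `tokens` an `IEnumerable<int>`):

```csharp
using System;
using System.Text;
using LLama.Transform;

var decoder = new StreamingTokenDecoder(Encoding.UTF8); // no model bound at construction
decoder.AddRange(tokens, model.NativeHandle);           // model handle supplied with the tokens instead
Console.Write(decoder.Read());
```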