@@ -1,7 +1,9 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Control;
using LLama.Native;
using LLama.Sampling;
using LLama.Transform;
namespace LLama.Abstractions
{
@@ -114,5 +116,15 @@ namespace LLama.Abstractions
/// Set a custom sampling pipeline to use. <b>If this is set All other sampling parameters are ignored!</b>
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }
/// <summary>
/// Set a custom generation control to use. <b>If this is set, the antiprompts will be ignored!</b>
/// </summary>
IGenerationControl GenerationControl { get; set; }
/// <summary>
/// Set a custom tokenizer to use.
/// </summary>
ITokenizer Tokenizer { get; set; }
}
}
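As a usage sketch (not part of this diff), the two new properties sit alongside the existing ones on an `InferenceParams` instance; the defaults shown are the ones this PR wires up, and the surrounding setup is assumed:

```csharp
// Hypothetical usage of the new IInferenceParams members added in this PR.
// DefaultGenerationControl and DefaultTokenizer are the defaults, so they only
// need to be set explicitly when swapping in custom behaviour.
var inferenceParams = new InferenceParams
{
    AntiPrompts = new List<string> { "User:" },
    GenerationControl = new DefaultGenerationControl(), // antiprompt-based stopping
    Tokenizer = new DefaultTokenizer(Encoding.UTF8)     // default tokenize/detokenize
};
```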
@@ -18,8 +18,8 @@ namespace LLama.Abstractions
/// </summary>
/// <param name="text">Your prompt</param>
/// <param name="inferenceParams">Any additional parameters</param>
/// <param name="token">A cancellation token.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default);
}
}
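A small, hypothetical call site for the renamed parameter (the executor instance and prompt are assumed):

```csharp
// Cancel inference after 30 seconds using the renamed cancellationToken parameter.
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
await foreach (var piece in executor.InferAsync("Hello", inferenceParams: null, cancellationToken: cts.Token))
{
    Console.Write(piece);
}
```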
@@ -3,6 +3,9 @@ using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;
using LLama.Transform;
using System.Text;
namespace LLama.Common
{
@@ -80,6 +83,12 @@ namespace LLama.Common
/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
/// <inheritdoc />
public IGenerationControl GenerationControl { get; set; } = new DefaultGenerationControl();
/// <inheritdoc />
public ITokenizer Tokenizer { get; set; } = new DefaultTokenizer(Encoding.UTF8);
}
/// <summary>
@@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;
namespace LLama
namespace LLama.Control
{
/// <summary>
/// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
@@ -0,0 +1,42 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama.Control
{
/// <summary>
/// The default generation control in LLamaSharp, which stops generation when an antiprompt is produced. This class cannot be inherited.
/// <b>Note that this class is stateful: previously processed output affects its decisions.</b>
/// If you use it in a session, please don't reuse the same instance for another session unless you intend to.
/// </summary>
public sealed class DefaultGenerationControl: IGenerationControl
{
private AntipromptProcessor _antipromptProcessor;
/// <summary>
/// <inheritdoc/>
/// </summary>
public DefaultGenerationControl()
{
_antipromptProcessor = new AntipromptProcessor();
}
/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
{
_antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
return _antipromptProcessor.Add(lastOutputText);
}
/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
{
return false;
}
}
}
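Since `DefaultGenerationControl` wraps a stateful `AntipromptProcessor`, a reasonable pattern (illustrative only, not part of this diff) is one instance per session, so antiprompt matching state cannot leak between conversations:

```csharp
// Hypothetical sketch: each session gets its own control instance.
var sessionAParams = new InferenceParams { GenerationControl = new DefaultGenerationControl() };
var sessionBParams = new InferenceParams { GenerationControl = new DefaultGenerationControl() };
```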
@@ -0,0 +1,31 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama.Control
{
/// <summary>
/// Control the text generation of LLama Executors.
/// </summary>
public interface IGenerationControl
{
/// <summary>
/// Use the last output text to determine if the generation should stop.
/// </summary>
/// <param name="context">The context used for this generation.</param>
/// <param name="inferenceParams">The inference parameters currently in effect.</param>
/// <param name="lastOutputText">The most recently generated piece of text.</param>
/// <returns>Whether the generation should stop.</returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);
/// <summary>
/// Use the last output ids to determine if the generation should stop.
/// </summary>
/// <param name="context">The context used for this generation.</param>
/// <param name="inferenceParams">The inference parameters currently in effect.</param>
/// <param name="lastOutputIds">The most recently generated token ids.</param>
/// <returns>Whether the generation should stop.</returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
}
}
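For illustration only (not part of this PR), a custom implementation could stop generation once a character budget is exhausted, independent of antiprompts:

```csharp
using System.Collections.Generic;
using LLama.Abstractions;

namespace LLama.Control
{
    // Hypothetical example implementation: stop once a fixed amount of text
    // has been produced. Like DefaultGenerationControl, it is stateful.
    public sealed class LengthLimitedGenerationControl : IGenerationControl
    {
        private readonly int _maxChars;
        private int _generatedChars;

        public LengthLimitedGenerationControl(int maxChars) => _maxChars = maxChars;

        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
        {
            _generatedChars += lastOutputText.Length;
            return _generatedChars >= _maxChars;
        }

        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
        {
            // Only the text-based overload is used for the budget here.
            return false;
        }
    }
}
```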
@@ -2,6 +2,7 @@
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
using LLama.Transform;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
@@ -8,6 +8,7 @@ using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -49,7 +50,7 @@ namespace LLama
/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Ensure the context from last time is disposed (it always hould be)
// Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();
@@ -57,48 +58,53 @@ namespace LLama
using var context = _weights.CreateContext(_params, _logger);
Context = context;
await foreach(var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
{
yield return item;
}
}
public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Sanity check inference params
inferenceParams ??= new InferenceParams();
if (inferenceParams.TokensKeep > Context.ContextSize)
throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
// Create decoders for the token stream
var decoder = new StreamingTokenDecoder(Context);
var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
if (inferenceParams.TokensKeep > context.ContextSize)
throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");
// Keep track of the last N tokens emitted
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<llama_token>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);
// Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList();
var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;
// Evaluate the prompt
await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);
// Begin loop, evaluating one token at a time
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
// Sample a single token
id = Context.Sample(
id = context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
@@ -106,12 +112,11 @@ namespace LLama
}
// Decode this token into text
decoder.Add(id);
var decoded = decoder.Read();
var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
yield return decoded;
// Check if any of the antiprompts have been generated
if (antiprocessor.Add(decoded))
// Check if the generation should stop
if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
break;
lastTokens.Add(id);
@@ -120,19 +125,19 @@ namespace LLama
// when run out of context
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
if (n_past + tokens.Count >= Context.ContextSize)
if (n_past + tokens.Count >= context.ContextSize)
{
var n_left = n_past - inferenceParams.TokensKeep - 1;
var n_discard = n_left / 2;
NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
}
// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
}
}
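If the new static `InferAsync` overload stays public, it could be driven with a caller-owned context roughly as below; the containing class name (assumed to be the stateless executor), the model path, and the loading calls around it are assumptions, not part of this diff:

```csharp
// Hypothetical sketch: drive the static overload with an externally created context.
var modelParams = new ModelParams("path/to/model.gguf"); // placeholder path
using var weights = LLamaWeights.LoadFromFile(modelParams);
using var context = weights.CreateContext(modelParams);

await foreach (var piece in StatelessExecutor.InferAsync("Write a haiku.", context, new InferenceParams()))
{
    Console.Write(piece);
}
```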
@@ -0,0 +1,31 @@
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
namespace LLama
{
/// <summary>
/// A class to execute text completion tasks.
/// </summary>
public class TextCompletion
{
public string Execute(string prompt, IInferenceParams? inferenceParams = null)
{
throw new NotImplementedException();
}
public ChatHistory Execute(ChatHistory prompt, IInferenceParams? inferenceParams = null)
{
throw new NotImplementedException();
}
public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
}
}
@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama.Transform
{
/// <summary>
/// The default tokenizer of LLamaSharp. This class cannot be inherited.
/// <b>Note that this class is stateful: its internal token decoder carries state between calls, so previously detokenized tokens affect later output.</b>
/// If you use it in a session, please don't reuse the same instance for another session unless you intend to.
/// </summary>
public sealed class DefaultTokenizer: ITokenizer
{
private Encoding _encoding;
private StreamingTokenDecoder _tokenDecoder;
/// <summary>
/// Initialize a new tokenizer with the specified encoding.
/// </summary>
/// <param name="encoding">The text encoding used when decoding tokens to text.</param>
public DefaultTokenizer(Encoding encoding)
{
_encoding = encoding;
_tokenDecoder = new StreamingTokenDecoder(encoding);
}
/// <summary>
/// <inheritdoc/>
/// </summary>
public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
{
return context.Tokenize(text, addBos, special);
}
/// <summary>
/// <inheritdoc/>
/// </summary>
public string Detokenize(LLamaContext context, int token)
{
_tokenDecoder.Add(token, context.NativeHandle.ModelHandle);
return _tokenDecoder.Read();
}
/// <summary>
/// <inheritdoc/>
/// </summary>
public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
{
_tokenDecoder.AddRange(tokens, context.NativeHandle.ModelHandle);
return _tokenDecoder.Read();
}
}
}
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama.Transform
{
public interface ITokenizer
{
IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false);
string Detokenize(LLamaContext context, int token);
string Detokenize(LLamaContext context, IEnumerable<int> tokens);
}
}
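As an illustration only, a custom `ITokenizer` could wrap the default one while changing a single behaviour, for example never adding a BOS token when the prompt continues text already in the context (the class and its behaviour are hypothetical, not part of this diff):

```csharp
using System.Collections.Generic;
using System.Text;

namespace LLama.Transform
{
    // Hypothetical example implementation: delegate to DefaultTokenizer,
    // but force addBos to false for every Tokenize call.
    public sealed class NoBosTokenizer : ITokenizer
    {
        private readonly DefaultTokenizer _inner = new DefaultTokenizer(Encoding.UTF8);

        public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
        {
            return _inner.Tokenize(context, text, false, special);
        }

        public string Detokenize(LLamaContext context, int token)
        {
            return _inner.Detokenize(context, token);
        }

        public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
        {
            return _inner.Detokenize(context, tokens);
        }
    }
}
```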
@@ -6,14 +6,14 @@ using System.Text;
using LLama.Extensions;
using LLama.Native;
namespace LLama
namespace LLama.Transform
{
/// <summary>
/// Decodes a stream of tokens into a stream of characters
/// </summary>
public sealed class StreamingTokenDecoder
{
private readonly SafeLlamaModelHandle _weights;
private readonly SafeLlamaModelHandle? _weights;
private readonly Decoder _decoder;
private readonly List<char> _characters = new();
@@ -29,8 +29,8 @@ namespace LLama
/// </summary>
/// <param name="encoding">Text encoding to use</param>
/// <param name="weights">Model weights</param>
public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
: this(encoding, weights.NativeHandle)
public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
: this(encoding, weights?.NativeHandle)
{
}
@@ -69,14 +69,19 @@ namespace LLama
/// Add a single token to the decoder
/// </summary>
/// <param name="token"></param>
public void Add(int token)
public void Add(int token, SafeLlamaModelHandle? weights = null)
{
weights ??= _weights;
if(weights is null)
{
throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
}
var charsArr = ArrayPool<char>.Shared.Rent(16);
var bytesArr = ArrayPool<byte>.Shared.Rent(16);
try
{
// Convert this token into bytes
var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;
// Convert those bytes into characters
var bytesOffset = 0;
@@ -133,10 +138,10 @@ namespace LLama
/// Add all tokens in the given enumerable
/// </summary>
/// <param name="tokens"></param>
public void AddRange(IEnumerable<int> tokens)
public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
{
foreach (var item in tokens)
Add(item);
Add(item, weights);
}
/// <summary>