diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index e1e89414..6156ade8 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -1,7 +1,9 @@
using System.Collections.Generic;
using LLama.Common;
+using LLama.Control;
using LLama.Native;
using LLama.Sampling;
+using LLama.Transform;
namespace LLama.Abstractions
{
@@ -114,5 +116,15 @@ namespace LLama.Abstractions
/// Set a custom sampling pipeline to use. If this is set, all other sampling parameters are ignored!
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }
+
+ /// <summary>
+ /// Set a custom generation control to use. If this is set, the antiprompt will be ignored!
+ /// </summary>
+ IGenerationControl GenerationControl { get; set; }
+
+ /// <summary>
+ /// Set a custom tokenizer to use.
+ /// </summary>
+ ITokenizer Tokenizer { get; set; }
}
}
\ No newline at end of file
diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs
index ef5453a7..4fd8600d 100644
--- a/LLama/Abstractions/ILLamaExecutor.cs
+++ b/LLama/Abstractions/ILLamaExecutor.cs
@@ -18,8 +18,8 @@ namespace LLama.Abstractions
/// </summary>
/// <param name="text">Your prompt</param>
/// <param name="inferenceParams">Any additional parameters</param>
- /// <param name="token">A cancellation token.</param>
+ /// <param name="cancellationToken">A cancellation token.</param>
/// <returns></returns>
- IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
+ IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default);
}
}
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index c1f39550..a47408f1 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -3,6 +3,9 @@ using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
+using LLama.Control;
+using LLama.Transform;
+using System.Text;
namespace LLama.Common
{
@@ -80,6 +83,12 @@ namespace LLama.Common
///
public ISamplingPipeline? SamplingPipeline { get; set; }
+
+ /// <inheritdoc />
+ public IGenerationControl GenerationControl { get; set; } = new DefaultGenerationControl();
+
+ /// <inheritdoc />
+ public ITokenizer Tokenizer { get; set; } = new DefaultTokenizer(Encoding.UTF8);
}
///
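For orientation, a minimal caller-side sketch (not part of this diff) of how the two new `InferenceParams` properties can be overridden. Only `GenerationControl`, `Tokenizer`, and their defaults come from this change; the surrounding values are illustrative.

```csharp
using System.Collections.Generic;
using System.Text;
using LLama.Common;
using LLama.Control;
using LLama.Transform;

// Hypothetical configuration; model/executor setup is assumed to exist elsewhere.
var inferenceParams = new InferenceParams
{
    MaxTokens = 128,
    AntiPrompts = new List<string> { "User:" },

    // New in this change: both properties default to the antiprompt-based control
    // and a UTF-8 streaming tokenizer, so existing callers keep the previous behaviour.
    GenerationControl = new DefaultGenerationControl(),
    Tokenizer = new DefaultTokenizer(Encoding.UTF8),
};
```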
diff --git a/LLama/AntipromptProcessor.cs b/LLama/Control/AntipromptProcessor.cs
similarity index 99%
rename from LLama/AntipromptProcessor.cs
rename to LLama/Control/AntipromptProcessor.cs
index c18c0915..9b5ff987 100644
--- a/LLama/AntipromptProcessor.cs
+++ b/LLama/Control/AntipromptProcessor.cs
@@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;
-namespace LLama
+namespace LLama.Control
{
///
/// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
diff --git a/LLama/Control/DefaultGenerationControl.cs b/LLama/Control/DefaultGenerationControl.cs
new file mode 100644
index 00000000..fe8f7d94
--- /dev/null
+++ b/LLama/Control/DefaultGenerationControl.cs
@@ -0,0 +1,42 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+ /// <summary>
+ /// The default generation control in LLamaSharp, using antiprompts. This class cannot be inherited.
+ /// Note that this class has state. The previous outputs fed to it will affect its control decisions.
+ /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+ /// </summary>
+ public sealed class DefaultGenerationControl: IGenerationControl
+ {
+ private AntipromptProcessor _antipromptProcessor;
+
+ ///
+ ///
+ ///
+ public DefaultGenerationControl()
+ {
+ _antipromptProcessor = new AntipromptProcessor();
+ }
+
+ ///
+ ///
+ ///
+ public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
+ {
+ _antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
+ return _antipromptProcessor.Add(lastOutputText);
+ }
+
+ ///
+ ///
+ ///
+ public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
+ {
+ return false;
+ }
+ }
+}
diff --git a/LLama/Control/IGenerationControl.cs b/LLama/Control/IGenerationControl.cs
new file mode 100644
index 00000000..3e01d284
--- /dev/null
+++ b/LLama/Control/IGenerationControl.cs
@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+ /// <summary>
+ /// Control the text generation of LLama Executors.
+ /// </summary>
+ public interface IGenerationControl
+ {
+ /// <summary>
+ /// Use the last output text to determine if the generation should stop.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="inferenceParams"></param>
+ /// <param name="lastOutputText"></param>
+ /// <returns></returns>
+ bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);
+
+ /// <summary>
+ /// Use the last output ids to determine if the generation should stop.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="inferenceParams"></param>
+ /// <param name="lastOutputIds"></param>
+ /// <returns></returns>
+ bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
+ }
+}
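Since `IGenerationControl` replaces the hard-wired antiprompt check, a custom implementation can stop generation on any condition. Below is a hypothetical sketch (not part of this diff) that stops once a fixed number of characters has been produced; the class name and character budget are made up for illustration.

```csharp
using System.Collections.Generic;
using LLama;
using LLama.Abstractions;
using LLama.Control;

// Hypothetical control: stop after a fixed number of generated characters,
// ignoring antiprompts entirely.
public sealed class MaxCharsGenerationControl : IGenerationControl
{
    private readonly int _maxChars;
    private int _emitted;

    public MaxCharsGenerationControl(int maxChars) => _maxChars = maxChars;

    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
    {
        _emitted += lastOutputText.Length;
        return _emitted >= _maxChars;
    }

    // This control only looks at decoded text, so token ids never trigger a stop.
    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
    {
        return false;
    }
}
```

It would then be plugged in via `inferenceParams.GenerationControl = new MaxCharsGenerationControl(2048);`.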
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index e0fde1ed..7b7ccd24 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -2,6 +2,7 @@
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
+using LLama.Transform;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 831aceb2..75101674 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -8,6 +8,7 @@ using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
+using LLama.Control;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -49,7 +50,7 @@ namespace LLama
///
public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- // Ensure the context from last time is disposed (it always hould be)
+ // Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();
@@ -57,48 +58,53 @@ namespace LLama
using var context = _weights.CreateContext(_params, _logger);
Context = context;
+ await foreach(var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
+ {
+ yield return item;
+ }
+ }
+
+ public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+
// Sanity check inference params
inferenceParams ??= new InferenceParams();
- if (inferenceParams.TokensKeep > Context.ContextSize)
- throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
-
- // Create decoders for the token stream
- var decoder = new StreamingTokenDecoder(Context);
- var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
+ if (inferenceParams.TokensKeep > context.ContextSize)
+ throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");
// Keep track of the last N tokens emitted
- var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
+ var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<llama_token>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);
// Tokenize the prompt
- var tokens = Context.Tokenize(prompt).ToList();
+ var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;
// Evaluate the prompt
- await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
+ await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);
// Begin loop, evaluating one token at a time
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
- for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
+ for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
- id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+ id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
- var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+ var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
// Sample a single token
- id = Context.Sample(
+ id = context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
@@ -106,12 +112,11 @@ namespace LLama
}
// Decode this token into text
- decoder.Add(id);
- var decoded = decoder.Read();
+ var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
yield return decoded;
- // Check if any of the antiprompts have been generated
- if (antiprocessor.Add(decoded))
+ // Check if the generation should stop
+ if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
break;
lastTokens.Add(id);
@@ -120,19 +125,19 @@ namespace LLama
// when run out of context
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
- if (n_past + tokens.Count >= Context.ContextSize)
+ if (n_past + tokens.Count >= context.ContextSize)
{
var n_left = n_past - inferenceParams.TokensKeep - 1;
var n_discard = n_left / 2;
- NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
- NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+ NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+ NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
}
// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
- n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
+ n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
}
}
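The new static overload lets a caller drive stateless inference against a context it owns. A hedged usage sketch follows (not part of this diff); the model path is a placeholder and the enclosing class is assumed to be the `StatelessExecutor` defined in this file.

```csharp
using System;
using LLama;
using LLama.Common;

// Hypothetical usage of the new static overload; "model.gguf" is a placeholder path.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

var inferenceParams = new InferenceParams { MaxTokens = 64 };
await foreach (var piece in StatelessExecutor.InferAsync("Question: what is 1 + 1?", context, inferenceParams))
{
    Console.Write(piece);
}
```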
diff --git a/LLama/TextCompletion.cs b/LLama/TextCompletion.cs
new file mode 100644
index 00000000..59b6ba98
--- /dev/null
+++ b/LLama/TextCompletion.cs
@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using LLama.Common;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+
+namespace LLama
+{
+ /// <summary>
+ /// A class to execute text completion tasks.
+ /// </summary>
+ public class TextCompletion
+ {
+ public string Execute(string prompt, IInferenceParams? inferenceParams = null)
+ {
+ throw new NotImplementedException();
+ }
+
+ public ChatHistory Execute(ChatHistory prompt, IInferenceParams? inferenceParams = null)
+ {
+ throw new NotImplementedException();
+ }
+
+ public IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ throw new NotImplementedException();
+ }
+ }
+}
diff --git a/LLama/Transform/DefaultTokenizer.cs b/LLama/Transform/DefaultTokenizer.cs
new file mode 100644
index 00000000..451c1d85
--- /dev/null
+++ b/LLama/Transform/DefaultTokenizer.cs
@@ -0,0 +1,53 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+ /// <summary>
+ /// The default tokenizer of LLamaSharp. This class cannot be inherited.
+ /// Note that this class has state. The previous tokens fed to it will affect its output.
+ /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+ /// </summary>
+ public sealed class DefaultTokenizer: ITokenizer
+ {
+ private Encoding _encoding;
+ private StreamingTokenDecoder _tokenDecoder;
+
+ /// <summary>
+ /// Initialize a new tokenizer with the specified encoding.
+ /// </summary>
+ /// <param name="encoding"></param>
+ public DefaultTokenizer(Encoding encoding)
+ {
+ _encoding = encoding;
+ _tokenDecoder = new StreamingTokenDecoder(encoding);
+ }
+
+ ///
+ ///
+ ///
+ public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
+ {
+ return context.Tokenize(text, addBos, special);
+ }
+
+ ///
+ ///
+ ///
+ public string Detokenize(LLamaContext context, int token)
+ {
+ _tokenDecoder.Add(token, context.NativeHandle.ModelHandle);
+ return _tokenDecoder.Read();
+ }
+
+ ///
+ ///
+ ///
+ public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
+ {
+ _tokenDecoder.AddRange(tokens, context.NativeHandle.ModelHandle);
+ return _tokenDecoder.Read();
+ }
+ }
+}
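As a quick illustration of how the new tokenizer abstraction is meant to be used, here is a hypothetical round-trip sketch (not part of this diff); the model path is a placeholder.

```csharp
using System.Text;
using LLama;
using LLama.Common;
using LLama.Transform;

// Hypothetical round-trip through the default tokenizer; "model.gguf" is a placeholder path.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

var tokenizer = new DefaultTokenizer(Encoding.UTF8);

// Tokenize delegates to the context's own tokenizer.
var ids = tokenizer.Tokenize(context, "Hello, world!");

// Detokenize streams the ids through the shared StreamingTokenDecoder and
// returns whatever complete characters are available so far.
var text = tokenizer.Detokenize(context, ids);
```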
diff --git a/LLama/Transform/ITokenizer.cs b/LLama/Transform/ITokenizer.cs
new file mode 100644
index 00000000..3df9fc46
--- /dev/null
+++ b/LLama/Transform/ITokenizer.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+ public interface ITokenizer
+ {
+ IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false);
+
+ string Detokenize(LLamaContext context, int token);
+
+ string Detokenize(LLamaContext context, IEnumerable<int> tokens);
+ }
+}
diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/Transform/StreamingTokenDecoder.cs
similarity index 90%
rename from LLama/StreamingTokenDecoder.cs
rename to LLama/Transform/StreamingTokenDecoder.cs
index f82f8c37..0653bd4d 100644
--- a/LLama/StreamingTokenDecoder.cs
+++ b/LLama/Transform/StreamingTokenDecoder.cs
@@ -6,14 +6,14 @@ using System.Text;
using LLama.Extensions;
using LLama.Native;
-namespace LLama
+namespace LLama.Transform
{
///
/// Decodes a stream of tokens into a stream of characters
///
public sealed class StreamingTokenDecoder
{
- private readonly SafeLlamaModelHandle _weights;
+ private readonly SafeLlamaModelHandle? _weights;
private readonly Decoder _decoder;
private readonly List<char> _characters = new();
@@ -29,8 +29,8 @@ namespace LLama
///
/// Text encoding to use
/// Model weights
- public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
- : this(encoding, weights.NativeHandle)
+ public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
+ : this(encoding, weights?.NativeHandle)
{
}
@@ -69,14 +69,19 @@ namespace LLama
/// Add a single token to the decoder
///
///
- public void Add(int token)
+ public void Add(int token, SafeLlamaModelHandle? weights = null)
{
+ weights ??= _weights;
+ if(weights is null)
+ {
+ throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
+ }
var charsArr = ArrayPool<char>.Shared.Rent(16);
var bytesArr = ArrayPool<byte>.Shared.Rent(16);
try
{
// Convert this token into bytes
- var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
+ var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;
// Convert those bytes into characters
var bytesOffset = 0;
@@ -133,10 +138,10 @@ namespace LLama
/// Add all tokens in the given enumerable
///
///
- public void AddRange(IEnumerable<int> tokens)
+ public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
{
foreach (var item in tokens)
- Add(item);
+ Add(item, weights);
}
///