Compare commits

...

1 commit

Author: Rinne | SHA1: 4f44e3b198 | Date: 1 year ago
Message: refactor: init some refactorings for experiment.

12 changed files with 237 additions and 33 deletions
Unified View
 1. LLama/Abstractions/IInferenceParams.cs (+12, -0)
 2. LLama/Abstractions/ILLamaExecutor.cs (+2, -2)
 3. LLama/Common/InferenceParams.cs (+9, -0)
 4. LLama/Control/AntipromptProcessor.cs (+1, -1)
 5. LLama/Control/DefaultGenerationControl.cs (+42, -0)
 6. LLama/Control/IGenerationControl.cs (+31, -0)
 7. LLama/LLamaExecutorBase.cs (+1, -0)
 8. LLama/LLamaStatelessExecutor.cs (+27, -22)
 9. LLama/TextCompletion.cs (+31, -0)
10. LLama/Transform/DefaultTokenizer.cs (+53, -0)
11. LLama/Transform/ITokenizer.cs (+15, -0)
12. LLama/Transform/StreamingTokenDecoder.cs (+13, -8)

LLama/Abstractions/IInferenceParams.cs (+12, -0)

@@ -1,7 +1,9 @@
 using System.Collections.Generic;
 using LLama.Common;
+using LLama.Control;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Transform;
 
 namespace LLama.Abstractions
 {
@@ -114,5 +116,15 @@ namespace LLama.Abstractions
         /// Set a custom sampling pipeline to use. <b>If this is set All other sampling parameters are ignored!</b>
         /// </summary>
         ISamplingPipeline? SamplingPipeline { get; set; }
+
+        /// <summary>
+        /// Set a custom generation control to use. <b>If this is set antiprompt will be ignored!</b>
+        /// </summary>
+        IGenerationControl GenerationControl { get; set; }
+
+        /// <summary>
+        /// Set a custom tokenizer to use.
+        /// </summary>
+        ITokenizer Tokenizer { get; set; }
     }
 }

LLama/Abstractions/ILLamaExecutor.cs (+2, -2)

@@ -18,8 +18,8 @@ namespace LLama.Abstractions
         /// </summary>
         /// <param name="text">Your prompt</param>
         /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
+        /// <param name="cancellationToken">A cancellation token.</param>
         /// <returns></returns>
-        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
+        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default);
     }
 }

LLama/Common/InferenceParams.cs (+9, -0)

@@ -3,6 +3,9 @@ using System;
 using System.Collections.Generic;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Control;
+using LLama.Transform;
+using System.Text;
 
 namespace LLama.Common
 {
@@ -80,6 +83,12 @@ namespace LLama.Common
 
         /// <inheritdoc />
         public ISamplingPipeline? SamplingPipeline { get; set; }
+
+        /// <inheritdoc />
+        public IGenerationControl GenerationControl { get; set; } = new DefaultGenerationControl();
+
+        /// <inheritdoc />
+        public ITokenizer Tokenizer { get; set; } = new DefaultTokenizer(Encoding.UTF8);
     }
 
     /// <summary>

LLama/AntipromptProcessor.cs → LLama/Control/AntipromptProcessor.cs (+1, -1)

@@ -1,7 +1,7 @@
 using System;
 using System.Collections.Generic;
 
-namespace LLama
+namespace LLama.Control
 {
     /// <summary>
     /// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts

LLama/Control/DefaultGenerationControl.cs (+42, -0)

@@ -0,0 +1,42 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
+    /// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
+    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+    /// </summary>
+    public sealed class DefaultGenerationControl: IGenerationControl
+    {
+        private AntipromptProcessor _antipromptProcessor;
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public DefaultGenerationControl()
+        {
+            _antipromptProcessor = new AntipromptProcessor();
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
+        {
+            _antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
+            return _antipromptProcessor.Add(lastOutputText);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
+        {
+            return false;
+        }
+    }
+}

LLama/Control/IGenerationControl.cs (+31, -0)

@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// Control the text generation of LLama Executors.
+    /// </summary>
+    public interface IGenerationControl
+    {
+        /// <summary>
+        /// Use the last output text to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputText"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);
+
+        /// <summary>
+        /// Use the last output ids to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputIds"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
+    }
+}
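
To illustrate the contract, a hypothetical IGenerationControl (not part of the commit) that stops generation after a fixed character budget. Only the interface members are taken from the diff; the class itself and its behaviour are assumptions.

    using System.Collections.Generic;
    using LLama;
    using LLama.Abstractions;
    using LLama.Control;

    // Hypothetical example: stop once roughly maxChars characters have been streamed.
    public sealed class MaxCharsControl : IGenerationControl
    {
        private readonly int _maxChars;
        private int _emitted;

        public MaxCharsControl(int maxChars) => _maxChars = maxChars;

        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
        {
            _emitted += lastOutputText.Length;
            return _emitted >= _maxChars;
        }

        // The id-based overload is unused here, mirroring DefaultGenerationControl.
        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
            => false;
    }

A caller would assign it via inferenceParams.GenerationControl = new MaxCharsControl(500); the stateless executor then consults it after each decoded chunk, as shown in LLamaStatelessExecutor.cs below.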

LLama/LLamaExecutorBase.cs (+1, -0)

@@ -2,6 +2,7 @@
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
+using LLama.Transform;
 using Microsoft.Extensions.Logging;
 using System;
 using System.Collections.Generic;

LLama/LLamaStatelessExecutor.cs (+27, -22)

@@ -8,6 +8,7 @@ using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Control;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -49,7 +50,7 @@ namespace LLama
         /// <inheritdoc />
         public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            // Ensure the context from last time is disposed (it always hould be)
+            // Ensure the context from last time is disposed (it always should be)
             if (!Context.NativeHandle.IsClosed)
                 Context.Dispose();
 
@@ -57,48 +58,53 @@ namespace LLama
             using var context = _weights.CreateContext(_params, _logger);
             Context = context;
 
+            await foreach(var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
+            {
+                yield return item;
+            }
+        }
+
+        public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+
             // Sanity check inference params
             inferenceParams ??= new InferenceParams();
-            if (inferenceParams.TokensKeep > Context.ContextSize)
-                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
-
-            // Create decoders for the token stream
-            var decoder = new StreamingTokenDecoder(Context);
-            var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
+            if (inferenceParams.TokensKeep > context.ContextSize)
+                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");
 
             // Keep track of the last N tokens emitted
-            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
+            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
             var lastTokens = new List<llama_token>(repeat_last_n);
             for (var i = 0; i < repeat_last_n; i++)
                 lastTokens.Add(0);
 
             // Tokenize the prompt
-            var tokens = Context.Tokenize(prompt).ToList();
+            var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
             lastTokens.AddRange(tokens);
             var n_past = 1 + tokens.Count;
 
             // Evaluate the prompt
-            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
+            await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
                 .ConfigureAwait(false);
 
             // Begin loop, evaluating one token at a time
             var mu = (float?)null;
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
-            for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
+            for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
            {
                 llama_token id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                    id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
                 }
                 else
                 {
                     // Penalize the generated tokens by various penalties
-                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     // Sample a single token
-                    id = Context.Sample(
+                    id = context.Sample(
                         tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                         inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
                         inferenceParams.MinP
@@ -106,12 +112,11 @@ namespace LLama
                 }
 
                 // Decode this token into text
-                decoder.Add(id);
-                var decoded = decoder.Read();
+                var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
                 yield return decoded;
 
-                // Check if any of the antiprompts have been generated
-                if (antiprocessor.Add(decoded))
+                // Check if the generation should stop
+                if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
                     break;
 
                 lastTokens.Add(id);
@@ -120,19 +125,19 @@ namespace LLama
 
                 // when run out of context
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
-                if (n_past + tokens.Count >= Context.ContextSize)
+                if (n_past + tokens.Count >= context.ContextSize)
                 {
                     var n_left = n_past - inferenceParams.TokensKeep - 1;
                     var n_discard = n_left / 2;
 
-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+                    NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
                 }
 
                 // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
+                n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
                     .ConfigureAwait(false);
             }
         }
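
A rough usage sketch (not part of the commit) of the new static overload, which runs inference against a caller-owned context. The overload's signature comes from the diff; the model-loading calls and the StatelessExecutor type name are assumptions about the surrounding library.

    // Assumed setup: load weights and create a context the caller manages itself.
    var modelParams = new ModelParams("path/to/model.gguf");
    using var weights = LLamaWeights.LoadFromFile(modelParams);
    using var context = weights.CreateContext(modelParams);

    // Stream completions through the static overload added in this commit.
    await foreach (var chunk in StatelessExecutor.InferAsync("Question: what is 1 + 1?\nAnswer:", context, new InferenceParams()))
    {
        Console.Write(chunk);
    }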


LLama/TextCompletion.cs (+31, -0)

@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using LLama.Common;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+
+namespace LLama
+{
+    /// <summary>
+    /// A class to execute text completion task.
+    /// </summary>
+    public class TextCompletion
+    {
+        public string Execute(string prompt, IInferenceParams? inferenceParams = null)
+        {
+            throw new NotImplementedException();
+        }
+
+        public ChatHistory Execute(ChatHistory prompt, IInferenceParams? inferenceParams = null)
+        {
+            throw new NotImplementedException();
+        }
+
+        public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            throw new NotImplementedException();
+        }
+    }
+}

LLama/Transform/DefaultTokenizer.cs (+53, -0)

@@ -0,0 +1,53 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+    /// <summary>
+    /// The default tokenizer of LLamaSharp. This class should not be inherited.
+    /// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
+    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+    /// </summary>
+    public sealed class DefaultTokenizer: ITokenizer
+    {
+        private Encoding _encoding;
+        private StreamingTokenDecoder _tokenDecoder;
+
+        /// <summary>
+        /// Initialize a new tokenizer with the specified encoding.
+        /// </summary>
+        /// <param name="encoding"></param>
+        public DefaultTokenizer(Encoding encoding)
+        {
+            _encoding = encoding;
+            _tokenDecoder = new StreamingTokenDecoder(encoding);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
+        {
+            return context.Tokenize(text, addBos, special);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public string Detokenize(LLamaContext context, int token)
+        {
+            _tokenDecoder.Add(token, context.NativeHandle.ModelHandle);
+            return _tokenDecoder.Read();
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
+        {
+            _tokenDecoder.AddRange(tokens, context.NativeHandle.ModelHandle);
+            return _tokenDecoder.Read();
+        }
+    }
+}
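
A small sketch (not part of the commit) of round-tripping text through the new tokenizer; `context` is assumed to be an existing LLamaContext.

    // Tokenize delegates to the context; Detokenize streams the ids through an internal StreamingTokenDecoder.
    var tokenizer = new DefaultTokenizer(Encoding.UTF8);
    var ids = tokenizer.Tokenize(context, "Hello, world!");
    var text = tokenizer.Detokenize(context, ids);

Because the internal decoder is stateful, a DefaultTokenizer instance should not be shared across unrelated sessions, as the class documentation above warns.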

LLama/Transform/ITokenizer.cs (+15, -0)

@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+    public interface ITokenizer
+    {
+        IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false);
+
+        string Detokenize(LLamaContext context, int token);
+
+        string Detokenize(LLamaContext context, IEnumerable<int> tokens);
+    }
+}
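
As with IGenerationControl, callers can provide their own implementation. A hypothetical decorator (not part of the commit) that delegates to DefaultTokenizer but normalizes decoded text, purely to illustrate the contract:

    using System.Collections.Generic;
    using System.Text;
    using LLama;
    using LLama.Transform;

    public sealed class TrimmingTokenizer : ITokenizer
    {
        private readonly DefaultTokenizer _inner = new(Encoding.UTF8);

        public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
            => _inner.Tokenize(context, text, addBos, special);

        // Normalize line endings in streamed output; purely illustrative.
        public string Detokenize(LLamaContext context, int token)
            => _inner.Detokenize(context, token).Replace("\r\n", "\n");

        public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
            => _inner.Detokenize(context, tokens).Replace("\r\n", "\n");
    }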

LLama/StreamingTokenDecoder.cs → LLama/Transform/StreamingTokenDecoder.cs (+13, -8)

@@ -6,14 +6,14 @@ using System.Text;
 using LLama.Extensions;
 using LLama.Native;
 
-namespace LLama
+namespace LLama.Transform
 {
     /// <summary>
     /// Decodes a stream of tokens into a stream of characters
     /// </summary>
     public sealed class StreamingTokenDecoder
     {
-        private readonly SafeLlamaModelHandle _weights;
+        private readonly SafeLlamaModelHandle? _weights;
         private readonly Decoder _decoder;
 
         private readonly List<char> _characters = new();
@@ -29,8 +29,8 @@ namespace LLama
         /// </summary>
         /// <param name="encoding">Text encoding to use</param>
         /// <param name="weights">Model weights</param>
-        public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
-            : this(encoding, weights.NativeHandle)
+        public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
+            : this(encoding, weights?.NativeHandle)
         {
         }
 
@@ -69,14 +69,19 @@ namespace LLama
         /// Add a single token to the decoder
         /// </summary>
         /// <param name="token"></param>
-        public void Add(int token)
+        public void Add(int token, SafeLlamaModelHandle? weights = null)
         {
+            weights ??= _weights;
+            if(weights is null)
+            {
+                throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
+            }
             var charsArr = ArrayPool<char>.Shared.Rent(16);
             var bytesArr = ArrayPool<byte>.Shared.Rent(16);
             try
             {
                 // Convert this token into bytes
-                var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
+                var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;
 
                 // Convert those bytes into characters
                 var bytesOffset = 0;
@@ -133,10 +138,10 @@ namespace LLama
         /// Add all tokens in the given enumerable
         /// </summary>
         /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<int> tokens)
+        public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
        {
             foreach (var item in tokens)
-                Add(item);
+                Add(item, weights);
         }
 
         /// <summary>