
refactor: init some refactorings for experiment.

refactor_v1.0
Rinne · 1 year ago
commit 4f44e3b198
12 changed files with 237 additions and 33 deletions

 1. LLama/Abstractions/IInferenceParams.cs       +12  -0
 2. LLama/Abstractions/ILLamaExecutor.cs          +2  -2
 3. LLama/Common/InferenceParams.cs               +9  -0
 4. LLama/Control/AntipromptProcessor.cs          +1  -1
 5. LLama/Control/DefaultGenerationControl.cs    +42  -0
 6. LLama/Control/IGenerationControl.cs          +31  -0
 7. LLama/LLamaExecutorBase.cs                    +1  -0
 8. LLama/LLamaStatelessExecutor.cs              +27 -22
 9. LLama/TextCompletion.cs                      +31  -0
10. LLama/Transform/DefaultTokenizer.cs          +53  -0
11. LLama/Transform/ITokenizer.cs                +15  -0
12. LLama/Transform/StreamingTokenDecoder.cs     +13  -8

LLama/Abstractions/IInferenceParams.cs  (+12 -0)

@@ -1,7 +1,9 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Control;
using LLama.Native;
using LLama.Sampling;
using LLama.Transform;

namespace LLama.Abstractions
{
@@ -114,5 +116,15 @@ namespace LLama.Abstractions
/// Set a custom sampling pipeline to use. <b>If this is set All other sampling parameters are ignored!</b>
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }

/// <summary>
/// Set a custom generation control to use. <b>If this is set antiprompt will be ignored!</b>
/// </summary>
IGenerationControl GenerationControl { get; set; }

/// <summary>
/// Set a custom tokenizer to use.
/// </summary>
ITokenizer Tokenizer { get; set; }
}
}

LLama/Abstractions/ILLamaExecutor.cs  (+2 -2)

@@ -18,8 +18,8 @@ namespace LLama.Abstractions
/// </summary>
/// <param name="text">Your prompt</param>
/// <param name="inferenceParams">Any additional parameters</param>
/// <param name="token">A cancellation token.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default);
}
}
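A minimal consumption sketch, not part of this commit, showing the renamed cancellationToken parameter in use; the class name, prompt, anti-prompt and limits are placeholders.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;

public static class InferenceDemo
{
    // Streams pieces from any ILLamaExecutor and stops after 30 seconds via the token.
    public static async Task StreamAsync(ILLamaExecutor executor, string prompt)
    {
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
        var inferenceParams = new InferenceParams
        {
            MaxTokens = 128,
            AntiPrompts = new List<string> { "User:" }
        };

        await foreach (var piece in executor.InferAsync(prompt, inferenceParams, cts.Token))
            Console.Write(piece);
    }
}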

LLama/Common/InferenceParams.cs  (+9 -0)

@@ -3,6 +3,9 @@ using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;
using LLama.Transform;
using System.Text;

namespace LLama.Common
{
@@ -80,6 +83,12 @@ namespace LLama.Common

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }

/// <inheritdoc />
public IGenerationControl GenerationControl { get; set; } = new DefaultGenerationControl();

/// <inheritdoc />
public ITokenizer Tokenizer { get; set; } = new DefaultTokenizer(Encoding.UTF8);
}

/// <summary>

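Both new properties ship with defaults, so existing callers keep the previous antiprompt-driven behaviour unchanged. A short sketch of setting them explicitly (illustrative only; any IGenerationControl or ITokenizer implementation could be substituted):

using System.Text;
using LLama.Common;
using LLama.Control;
using LLama.Transform;

var inferenceParams = new InferenceParams
{
    MaxTokens = 256,
    // Explicitly set the defaults introduced by this commit.
    GenerationControl = new DefaultGenerationControl(),
    Tokenizer = new DefaultTokenizer(Encoding.UTF8)
};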

LLama/AntipromptProcessor.cs → LLama/Control/AntipromptProcessor.cs  (+1 -1)

@@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;

- namespace LLama
+ namespace LLama.Control
{
/// <summary>
/// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts

LLama/Control/DefaultGenerationControl.cs  (+42 -0)

@@ -0,0 +1,42 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
/// <summary>
/// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
/// <b>Note that this class has state. The previous outputs fed to it will affect its control.</b>
/// If you use it in a session, please don't reuse it for another session unless you intend to do so.
/// </summary>
public sealed class DefaultGenerationControl: IGenerationControl
{
private AntipromptProcessor _antipromptProcessor;

/// <summary>
/// <inheritdoc/>
/// </summary>
public DefaultGenerationControl()
{
_antipromptProcessor = new AntipromptProcessor();
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
{
_antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
return _antipromptProcessor.Add(lastOutputText);
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
{
return false;
}
}
}
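Because DefaultGenerationControl keeps antiprompt-matching state in its AntipromptProcessor, the warning above implies one instance per session; a minimal sketch of that pattern:

using System.Collections.Generic;
using LLama.Common;
using LLama.Control;

// One control per conversation, so matching state from session A cannot leak into session B.
var sessionA = new InferenceParams { AntiPrompts = new List<string> { "User:" }, GenerationControl = new DefaultGenerationControl() };
var sessionB = new InferenceParams { AntiPrompts = new List<string> { "User:" }, GenerationControl = new DefaultGenerationControl() };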

LLama/Control/IGenerationControl.cs  (+31 -0)

@@ -0,0 +1,31 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
/// <summary>
/// Control the text generation of LLama Executors.
/// </summary>
public interface IGenerationControl
{
/// <summary>
/// Use the last output text to determine if the generation should stop.
/// </summary>
/// <param name="context"></param>
/// <param name="inferenceParams"></param>
/// <param name="lastOutputText"></param>
/// <returns></returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);

/// <summary>
/// Use the last output ids to determine if the generation should stop.
/// </summary>
/// <param name="context"></param>
/// <param name="inferenceParams"></param>
/// <param name="lastOutputIds"></param>
/// <returns></returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
}
}
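A hypothetical custom implementation, with invented name and stop condition, showing the shape the interface requires; it stops after a fixed number of decoded pieces or on a marker string, independent of AntiPrompts:

using System.Collections.Generic;
using LLama;
using LLama.Abstractions;
using LLama.Control;

public sealed class TokenBudgetControl : IGenerationControl
{
    private readonly int _budget;
    private int _generated;

    public TokenBudgetControl(int budget) => _budget = budget;

    // Called once per decoded piece by executors that use the text overload.
    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
    {
        _generated++;
        return _generated >= _budget || lastOutputText.Contains("</answer>");
    }

    // Id-based stopping is not used in this sketch.
    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
        => false;
}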

LLama/LLamaExecutorBase.cs  (+1 -0)

@@ -2,6 +2,7 @@
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
using LLama.Transform;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;


LLama/LLamaStatelessExecutor.cs  (+27 -22)

@@ -8,6 +8,7 @@ using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;
using Microsoft.Extensions.Logging;

namespace LLama
@@ -49,7 +50,7 @@ namespace LLama
/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- // Ensure the context from last time is disposed (it always hould be)
+ // Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();

@@ -57,48 +58,53 @@ namespace LLama
using var context = _weights.CreateContext(_params, _logger);
Context = context;

await foreach(var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
{
yield return item;
}
}

public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{

// Sanity check inference params
inferenceParams ??= new InferenceParams();
- if (inferenceParams.TokensKeep > Context.ContextSize)
- throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");

- // Create decoders for the token stream
- var decoder = new StreamingTokenDecoder(Context);
- var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
+ if (inferenceParams.TokensKeep > context.ContextSize)
+ throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");

// Keep track of the last N tokens emitted
- var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
+ var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<llama_token>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);

// Tokenize the prompt
- var tokens = Context.Tokenize(prompt).ToList();
+ var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;

// Evaluate the prompt
- await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
+ await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);

// Begin loop, evaluating one token at a time
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
- for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
+ for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
- id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+ id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
- var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+ var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
- id = Context.Sample(
+ id = context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
@@ -106,12 +112,11 @@ namespace LLama
}

// Decode this token into text
- decoder.Add(id);
- var decoded = decoder.Read();
+ var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
yield return decoded;

- // Check if any of the antiprompts have been generated
- if (antiprocessor.Add(decoded))
+ // Check if the generation should stop
+ if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
break;

lastTokens.Add(id);
@@ -120,19 +125,19 @@ namespace LLama

// when run out of context
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
- if (n_past + tokens.Count >= Context.ContextSize)
+ if (n_past + tokens.Count >= context.ContextSize)
{
var n_left = n_past - inferenceParams.TokensKeep - 1;
var n_discard = n_left / 2;

- NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
- NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+ NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+ NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

n_past -= n_discard;
}

// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
- n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
+ n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
}
}
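A sketch of driving the extracted static overload with a caller-owned context. It assumes the enclosing class keeps its current name, StatelessExecutor, that LLamaWeights.CreateContext accepts the same ModelParams used for loading, and it uses a placeholder model path:

using System;
using LLama;
using LLama.Common;

var modelParams = new ModelParams("path/to/model.gguf");   // placeholder path
using var weights = LLamaWeights.LoadFromFile(modelParams);
using var context = weights.CreateContext(modelParams);

// The caller now owns the context lifetime; the static overload only reads from it.
await foreach (var piece in StatelessExecutor.InferAsync("Once upon a time", context, new InferenceParams { MaxTokens = 64 }))
    Console.Write(piece);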


LLama/TextCompletion.cs  (+31 -0)

@@ -0,0 +1,31 @@
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;

namespace LLama
{
/// <summary>
/// A class to execute text completion task.
/// </summary>
public class TextCompletion
{
public string Execute(string prompt, IInferenceParams? inferenceParams = null)
{
throw new NotImplementedException();
}

public ChatHistory Execute(ChatHistory prompt, IInferenceParams? inferenceParams = null)
{
throw new NotImplementedException();
}

public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
}
}
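All three methods are stubs. Note that an async method returning IAsyncEnumerable<string> must contain a yield statement, so the streaming stub will need a real iterator body before it compiles. One possible shape, purely an assumption about where the class is heading, is to delegate to an executor:

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using LLama.Abstractions;

public class TextCompletion
{
    private readonly ILLamaExecutor _executor;

    public TextCompletion(ILLamaExecutor executor) => _executor = executor;

    // Hypothetical streaming implementation: forward every decoded piece from the executor.
    public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        await foreach (var piece in _executor.InferAsync(prompt, inferenceParams, cancellationToken))
            yield return piece;
    }
}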

LLama/Transform/DefaultTokenizer.cs  (+53 -0)

@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Transform
{
/// <summary>
/// The default tokenizer of LLamaSharp. This class should not be inherited.
/// <b>Note that this class has state. The previous outputs fed to it will affect its output.</b>
/// If you use it in a session, please don't reuse it for another session unless you intend to do so.
/// </summary>
public sealed class DefaultTokenizer: ITokenizer
{
private Encoding _encoding;
private StreamingTokenDecoder _tokenDecoder;

/// <summary>
/// Initialize a new tokenizer with the specified encoding.
/// </summary>
/// <param name="encoding"></param>
public DefaultTokenizer(Encoding encoding)
{
_encoding = encoding;
_tokenDecoder = new StreamingTokenDecoder(encoding);
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
{
return context.Tokenize(text, addBos, special);
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public string Detokenize(LLamaContext context, int token)
{
_tokenDecoder.Add(token, context.NativeHandle.ModelHandle);
return _tokenDecoder.Read();
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
{
_tokenDecoder.AddRange(tokens, context.NativeHandle.ModelHandle);
return _tokenDecoder.Read();
}
}
}
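A round-trip sketch with the new tokenizer; the LLamaContext is assumed to already exist, and addBos is disabled only to keep the round trip symmetric:

using System.Text;
using LLama;
using LLama.Transform;

static string RoundTrip(LLamaContext context, string text)
{
    ITokenizer tokenizer = new DefaultTokenizer(Encoding.UTF8);

    var builder = new StringBuilder();
    foreach (var token in tokenizer.Tokenize(context, text, addBos: false))
        builder.Append(tokenizer.Detokenize(context, token)); // the stateful decoder stitches multi-byte characters back together
    return builder.ToString();
}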

LLama/Transform/ITokenizer.cs  (+15 -0)

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Transform
{
public interface ITokenizer
{
IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false);

string Detokenize(LLamaContext context, int token);

string Detokenize(LLamaContext context, IEnumerable<int> tokens);
}
}
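A hypothetical custom ITokenizer, with invented name and trimming behaviour, assembled from the pieces this commit introduces:

using System.Collections.Generic;
using System.Text;
using LLama;
using LLama.Transform;

public sealed class TrimmingTokenizer : ITokenizer
{
    private readonly StreamingTokenDecoder _decoder = new StreamingTokenDecoder(Encoding.UTF8);

    // Normalize the input before handing it to the context's tokenizer.
    public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
        => context.Tokenize(text.Trim(), addBos, special);

    public string Detokenize(LLamaContext context, int token)
    {
        _decoder.Add(token, context.NativeHandle.ModelHandle);
        return _decoder.Read();
    }

    public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
    {
        _decoder.AddRange(tokens, context.NativeHandle.ModelHandle);
        return _decoder.Read();
    }
}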

LLama/StreamingTokenDecoder.cs → LLama/Transform/StreamingTokenDecoder.cs  (+13 -8)

@@ -6,14 +6,14 @@ using System.Text;
using LLama.Extensions;
using LLama.Native;

- namespace LLama
+ namespace LLama.Transform
{
/// <summary>
/// Decodes a stream of tokens into a stream of characters
/// </summary>
public sealed class StreamingTokenDecoder
{
- private readonly SafeLlamaModelHandle _weights;
+ private readonly SafeLlamaModelHandle? _weights;
private readonly Decoder _decoder;

private readonly List<char> _characters = new();
@@ -29,8 +29,8 @@ namespace LLama
/// </summary>
/// <param name="encoding">Text encoding to use</param>
/// <param name="weights">Model weights</param>
- public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
- : this(encoding, weights.NativeHandle)
+ public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
+ : this(encoding, weights?.NativeHandle)
{
}

@@ -69,14 +69,19 @@ namespace LLama
/// Add a single token to the decoder
/// </summary>
/// <param name="token"></param>
- public void Add(int token)
+ public void Add(int token, SafeLlamaModelHandle? weights = null)
{
weights ??= _weights;
if(weights is null)
{
throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
}
var charsArr = ArrayPool<char>.Shared.Rent(16);
var bytesArr = ArrayPool<byte>.Shared.Rent(16);
try
{
// Convert this token into bytes
- var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
+ var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;

// Convert those bytes into characters
var bytesOffset = 0;
@@ -133,10 +138,10 @@ namespace LLama
/// Add all tokens in the given enumerable
/// </summary>
/// <param name="tokens"></param>
- public void AddRange(IEnumerable<int> tokens)
+ public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
{
foreach (var item in tokens)
- Add(item);
+ Add(item, weights);
}

/// <summary>

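With the model handle now optional, the decoder supports two modes: bind the weights at construction, or pass them per call as DefaultTokenizer does. A short sketch, assuming an already-loaded LLamaWeights and a placeholder token id:

using System;
using System.Text;
using LLama;
using LLama.Transform;

static void DecoderModes(LLamaWeights weights)
{
    // Mode 1: bind the model up front, as before this change.
    var bound = new StreamingTokenDecoder(Encoding.UTF8, weights);
    bound.Add(42);                          // placeholder token id, decoded with the bound weights

    // Mode 2: construct without weights and supply them per call (the DefaultTokenizer pattern).
    var unbound = new StreamingTokenDecoder(Encoding.UTF8);
    unbound.Add(42, weights.NativeHandle);  // explicit weights for this call
    // Calling unbound.Add(42) with no weights anywhere would throw.

    Console.WriteLine(bound.Read() + unbound.Read());
}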