Browse Source

Simplified `LLamaInteractExecutor` antiprompt matching by using new extension method

tags/v0.6.0
Martin Evans 2 years ago
parent
commit
77bd090150
4 changed files with 45 additions and 23 deletions
  1. +3
    -0
      LLama/Common/FixedSizeQueue.cs
  2. +39
    -3
      LLama/Extensions/IReadOnlyListExtensions.cs
  3. +3
    -18
      LLama/LLamaInteractExecutor.cs
  4. +0
    -2
      LLama/LLamaStatelessExecutor.cs

+ 3
- 0
LLama/Common/FixedSizeQueue.cs View File

@@ -15,6 +15,8 @@ namespace LLama.Common
private readonly int _maxSize;
private readonly List<T> _storage;

internal IReadOnlyList<T> Items => _storage;

/// <summary>
/// Number of items in this queue
/// </summary>
@@ -57,6 +59,7 @@ namespace LLama.Common
if (_storage.Count > _maxSize)
throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
}

/// <summary>
/// Replace every item in the queue with the given value
/// </summary>


+ 39
- 3
LLama/Extensions/IReadOnlyListExtensions.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using LLama.Native;
@@ -28,10 +29,11 @@ namespace LLama.Extensions
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
internal static bool TokensEndsWithAnyString<TList>(this TList tokens, IReadOnlyList<string> queries, SafeLlamaModelHandle model, Encoding encoding)
where TList : IReadOnlyList<int>
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
{
if (queries.Count == 0 || tokens.Count == 0)
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;

// Find the length of the longest query
@@ -58,5 +60,39 @@ namespace LLama.Extensions
ArrayPool<char>.Shared.Return(builderArray);
}
}

/// <summary>
/// Check if the given set of tokens ends with any of the given strings.
/// Convenience overload accepting an <see cref="IList{T}"/> of queries; adapts it to the
/// <see cref="IReadOnlyList{T}"/> based overload without copying.
/// </summary>
/// <param name="tokens">Tokens to check</param>
/// <param name="queries">Strings to search for (may be null or empty, in which case the result is false)</param>
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns>True if the token sequence ends with any query string, otherwise false</returns>
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
    where TTokens : IReadOnlyList<int>
{
    // Nothing can match when there are no tokens or no queries to look for.
    if (tokens.Count == 0)
        return false;
    if (queries is not { Count: > 0 })
        return false;

    // Wrap the IList in a zero-allocation-ish adapter and delegate to the generic overload.
    return TokensEndsWithAnyString(tokens, new ReadonlyWrapper<string>(queries), model, encoding);
}

/// <summary>
/// A lightweight adapter exposing an <see cref="IList{T}"/> through the
/// <see cref="IReadOnlyList{T}"/> interface, so list-typed arguments can be passed
/// to methods constrained to read-only lists without copying the contents.
/// </summary>
private readonly struct ReadonlyWrapper<T>
    : IReadOnlyList<T>
{
    private readonly IList<T> _list;

    public ReadonlyWrapper(IList<T> list)
    {
        _list = list;
    }

    /// <summary>Number of items in the wrapped list</summary>
    public int Count => _list.Count;

    /// <summary>Get the item at the given index in the wrapped list</summary>
    public T this[int index] => _list[index];

    public IEnumerator<T> GetEnumerator() => _list.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_list).GetEnumerator();
}
}
}

+ 3
- 18
LLama/LLamaInteractExecutor.cs View File

@@ -8,6 +8,7 @@ using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text;
using LLama.Extensions;

namespace LLama
{
@@ -128,27 +129,11 @@ namespace LLama
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
{
var last_output_builder = new StringBuilder();
foreach (var token in _last_n_tokens)
Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
var last_output = last_output_builder.ToString();

foreach (var antiprompt in args.Antiprompts)
{
if (last_output.EndsWith(antiprompt))
{
args.WaitForInput = true;
break;
}
}
}
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
{
return true;
}
}

if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))


+ 0
- 2
LLama/LLamaStatelessExecutor.cs View File

@@ -1,11 +1,9 @@
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using LLama.Extensions;



Loading…
Cancel
Save