Browse Source

Simplified `LLamaInteractExecutor` antiprompt matching by using new extension method

tags/v0.6.0
Martin Evans 2 years ago
parent
commit
77bd090150
4 changed files with 45 additions and 23 deletions
  1. +3
    -0
      LLama/Common/FixedSizeQueue.cs
  2. +39
    -3
      LLama/Extensions/IReadOnlyListExtensions.cs
  3. +3
    -18
      LLama/LLamaInteractExecutor.cs
  4. +0
    -2
      LLama/LLamaStatelessExecutor.cs

+ 3
- 0
LLama/Common/FixedSizeQueue.cs View File

@@ -15,6 +15,8 @@ namespace LLama.Common
private readonly int _maxSize; private readonly int _maxSize;
private readonly List<T> _storage; private readonly List<T> _storage;


internal IReadOnlyList<T> Items => _storage;

/// <summary> /// <summary>
/// Number of items in this queue /// Number of items in this queue
/// </summary> /// </summary>
@@ -57,6 +59,7 @@ namespace LLama.Common
if (_storage.Count > _maxSize) if (_storage.Count > _maxSize)
throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
} }

/// <summary> /// <summary>
/// Replace every item in the queue with the given value /// Replace every item in the queue with the given value
/// </summary> /// </summary>


+ 39
- 3
LLama/Extensions/IReadOnlyListExtensions.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Buffers; using System.Buffers;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using LLama.Native; using LLama.Native;
@@ -28,10 +29,11 @@ namespace LLama.Extensions
/// <param name="model">Model to use to convert tokens into bytes</param> /// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param> /// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns> /// <returns></returns>
internal static bool TokensEndsWithAnyString<TList>(this TList tokens, IReadOnlyList<string> queries, SafeLlamaModelHandle model, Encoding encoding)
where TList : IReadOnlyList<int>
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
{ {
if (queries.Count == 0 || tokens.Count == 0)
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false; return false;


// Find the length of the longest query // Find the length of the longest query
@@ -58,5 +60,39 @@ namespace LLama.Extensions
ArrayPool<char>.Shared.Return(builderArray); ArrayPool<char>.Shared.Return(builderArray);
} }
} }

/// <summary>
/// Check if the given token sequence ends with any of the given query strings,
/// accepting a plain (possibly null) <see cref="IList{T}"/> of queries.
/// </summary>
/// <param name="tokens">Tokens to inspect</param>
/// <param name="queries">Candidate strings to match at the end of the token text; may be null</param>
/// <param name="model">Model used to convert tokens into bytes</param>
/// <param name="encoding">Encoding used to convert bytes into characters</param>
/// <returns>True if the detokenized text ends with any query; false otherwise (including null/empty inputs)</returns>
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
    where TTokens : IReadOnlyList<int>
{
    // Bail out early when there is nothing to match against.
    if (queries is null)
        return false;
    if (queries.Count == 0 || tokens.Count == 0)
        return false;

    // Adapt the writable list into an IReadOnlyList view (no copy) and defer
    // to the generic overload that does the actual matching work.
    var readOnlyQueries = new ReadonlyWrapper<string>(queries);
    return tokens.TokensEndsWithAnyString(readOnlyQueries, model, encoding);
}

/// <summary>
/// A zero-copy adapter exposing an <see cref="IList{T}"/> as an <see cref="IReadOnlyList{T}"/>.
/// All members simply delegate to the wrapped list.
/// </summary>
private readonly struct ReadonlyWrapper<T>
    : IReadOnlyList<T>
{
    // The underlying list; never copied, so mutations to it are visible through this wrapper.
    private readonly IList<T> _list;

    public ReadonlyWrapper(IList<T> list)
    {
        _list = list;
    }

    /// <inheritdoc />
    public int Count => _list.Count;

    /// <inheritdoc />
    public T this[int index] => _list[index];

    /// <inheritdoc />
    public IEnumerator<T> GetEnumerator() => _list.GetEnumerator();

    // Explicit non-generic enumeration, delegated through the non-generic interface view.
    IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_list).GetEnumerator();
}
} }
} }

+ 3
- 18
LLama/LLamaInteractExecutor.cs View File

@@ -8,6 +8,7 @@ using System.Linq;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Text; using System.Text;
using LLama.Extensions;


namespace LLama namespace LLama
{ {
@@ -128,27 +129,11 @@ namespace LLama
extraOutputs = null; extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount) if (_embed_inps.Count <= _consumedTokensCount)
{ {
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
{
var last_output_builder = new StringBuilder();
foreach (var token in _last_n_tokens)
Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
var last_output = last_output_builder.ToString();

foreach (var antiprompt in args.Antiprompts)
{
if (last_output.EndsWith(antiprompt))
{
args.WaitForInput = true;
break;
}
}
}
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
args.WaitForInput = true;


if (_pastTokensCount > 0 && args.WaitForInput) if (_pastTokensCount > 0 && args.WaitForInput)
{
return true; return true;
}
} }


if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))


+ 0
- 2
LLama/LLamaStatelessExecutor.cs View File

@@ -1,11 +1,9 @@
using LLama.Abstractions; using LLama.Abstractions;
using LLama.Common; using LLama.Common;
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text;
using System.Threading; using System.Threading;
using LLama.Extensions; using LLama.Extensions;




Loading…
Cancel
Save