Browse Source

- Added various convenience overloads to `LLamaContext.Eval`

- Converted `SafeLLamaContextHandle` to take a `ReadOnlySpan` for Eval, narrower type better represents what's really needed
tags/v0.5.1
Martin Evans 2 years ago
parent
commit
ae8ef17a4a
4 changed files with 76 additions and 7 deletions
  1. +1
    -1
      LLama.Examples/NewVersion/StatelessModeExecute.cs
  2. +69
    -2
      LLama/LLamaContext.cs
  3. +5
    -3
      LLama/Native/SafeLLamaContextHandle.cs
  4. +1
    -1
      LLama/Utils.cs

+ 1
- 1
LLama.Examples/NewVersion/StatelessModeExecute.cs View File

@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
Console.Write("\nQuestion: ");
Console.ForegroundColor = ConsoleColor.Green;
string prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))


+ 69
- 2
LLama/LLamaContext.cs View File

@@ -1,6 +1,7 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -384,6 +385,7 @@ namespace LLama
return candidates_p;
}

#region eval overloads
/// <summary>
///
/// </summary>
@@ -391,7 +393,61 @@ namespace LLama
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
public int Eval(llama_token[] tokens, llama_token pastTokensCount)
{
return Eval(tokens.AsSpan(), pastTokensCount);
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
return Eval(span, pastTokensCount);
#else
// on netstandard2.0 we can't use collections marshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
return Eval(rented, pastTokensCount);
}
finally
{
ArrayPool<llama_token>.Shared.Return(rented);
}
#endif
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount)
{
return Eval(tokens.Span, pastTokensCount);
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
{
int total = tokens.Length;
for(int i = 0; i < total; i += Params.BatchSize)
@@ -402,7 +458,7 @@ namespace LLama
n_eval = Params.BatchSize;
}

if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
{
_logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error);
throw new RuntimeError("Failed to eval.");
@@ -412,6 +468,7 @@ namespace LLama
}
return pastTokensCount;
}
#endregion

internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
{
@@ -419,6 +476,16 @@ namespace LLama
yield return _ctx.TokenToString(id, _encoding);
}

/// <summary>
/// Convert a token into a string
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
public string TokenToString(llama_token token)
{
return NativeHandle.TokenToString(token, Encoding);
}

/// <inheritdoc />
public virtual void Dispose()
{


+ 5
- 3
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -179,12 +179,14 @@ namespace LLama.Native
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <param name="n_threads"></param>
/// <returns>Returns true on success</returns>
public bool Eval(Memory<int> tokens, int n_past, int n_threads)
public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
{
using var pin = tokens.Pin();
unsafe
{
return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
fixed (int* pinned = tokens)
{
return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
}
}
}
}


+ 1
- 1
LLama/Utils.cs View File

@@ -37,7 +37,7 @@ namespace LLama
[Obsolete("Use SafeLLamaContextHandle Eval method instead")]
public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
{
var slice = tokens.AsMemory().Slice(startIndex, n_tokens);
var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
}



Loading…
Cancel
Save