@@ -1,6 +1,7 @@
 using LLama.Exceptions;
 using LLama.Native;
 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -384,6 +385,7 @@ namespace LLama
             return candidates_p;
         }
 
+        #region eval overloads
         /// <summary>
        ///
        /// </summary>
@@ -391,7 +393,61 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
+        public int Eval(llama_token[] tokens, llama_token pastTokensCount)
+        {
+            return Eval(tokens.AsSpan(), pastTokensCount);
+        }
+
+        /// <summary>
+        /// Eval a batch of tokens from a list.
+        /// </summary>
+        /// <param name="tokens">The tokens to evaluate</param>
+        /// <param name="pastTokensCount">The number of tokens already evaluated in this context</param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
+        {
+#if NET5_0_OR_GREATER
+            var span = CollectionsMarshal.AsSpan(tokens);
+            return Eval(span, pastTokensCount);
+#else
+            // On netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
+            // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
+            // avoid the copy. The rented array may be longer than the list, so evaluate only the copied prefix.
+
+            var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
+            try
+            {
+                tokens.CopyTo(rented, 0);
+                return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
+            }
+            finally
+            {
+                ArrayPool<llama_token>.Shared.Return(rented);
+            }
+#endif
+        }
+
+        /// <summary>
+        /// Eval a batch of tokens from a read-only memory buffer.
+        /// </summary>
+        /// <param name="tokens">The tokens to evaluate</param>
+        /// <param name="pastTokensCount">The number of tokens already evaluated in this context</param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount)
+        {
+            return Eval(tokens.Span, pastTokensCount);
+        }
+
+        /// <summary>
+        /// Eval a batch of tokens, splitting the work into batches of at most `Params.BatchSize` tokens.
+        /// </summary>
+        /// <param name="tokens">The tokens to evaluate</param>
+        /// <param name="pastTokensCount">The number of tokens already evaluated in this context</param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
         {
             int total = tokens.Length;
             for(int i = 0; i < total; i += Params.BatchSize)
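
The hunk above turns `Eval` into a family of overloads that all funnel into the `ReadOnlySpan` version. A minimal usage sketch follows, assuming an already-constructed `LLamaContext` named `ctx` (construction and model loading are out of scope here); the token ids are placeholders rather than real vocabulary entries, and `llama_token` is the file's alias for `int`:

    // Feed a prompt in two pieces, threading pastTokensCount through each call.
    int[] prompt = { 1, 15043, 2787 };      // hypothetical token ids
    var pastTokensCount = 0;

    // Array overload: forwards to the span overload without copying.
    pastTokensCount = ctx.Eval(prompt, pastTokensCount);

    // List overload: zero-copy via CollectionsMarshal on net5.0+;
    // on netstandard2.0 it copies into a rented array instead.
    var more = new List<int> { 366, 3186 }; // hypothetical token ids
    pastTokensCount = ctx.Eval(more, pastTokensCount);
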
@@ -402,7 +458,7 @@ namespace LLama
                     n_eval = Params.BatchSize;
                 }
 
-                if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
+                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
                 {
                     _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                     throw new RuntimeError("Failed to eval.");
@@ -412,6 +468,7 @@ namespace LLama
             }
             return pastTokensCount;
         }
+        #endregion
 
         internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
         {
@@ -419,6 +476,16 @@ namespace LLama
                 yield return _ctx.TokenToString(id, _encoding);
         }
 
+        /// <summary>
+        /// Convert a token into a string
+        /// </summary>
+        /// <param name="token">The token to convert</param>
+        /// <returns>The string representation of the token</returns>
+        public string TokenToString(llama_token token)
+        {
+            return NativeHandle.TokenToString(token, Encoding);
+        }
+
         /// <inheritdoc />
         public virtual void Dispose()
         {
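
The final hunk adds a public `TokenToString` wrapper around the native handle. A short, hypothetical decode loop showing how it pairs with sampled output; the sampling step is elided, and `sampled` stands in for whatever the caller's sampling pipeline produced:

    // Convert sampled token ids back into text using the context's Encoding.
    static string DecodeTokens(LLamaContext ctx, IEnumerable<int> sampled)
    {
        var sb = new StringBuilder();
        foreach (var token in sampled)
            sb.Append(ctx.TokenToString(token));
        return sb.ToString();
    }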