diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs
index dadaf70a..d43cdd79 100644
--- a/LLama.Examples/NewVersion/StatelessModeExecute.cs
+++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs
@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
                 Console.Write("\nQuestion: ");
                 Console.ForegroundColor = ConsoleColor.Green;
                 string prompt = Console.ReadLine();
-                Console.ForegroundColor = ConsoleColor.White;
+                Console.ForegroundColor = ConsoleColor.White;
                 Console.Write("Answer: ");
                 prompt = $"Question: {prompt.Trim()} Answer: ";
                 foreach (var text in ex.Infer(prompt, inferenceParams))
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 1ef2a8db..4fb601d4 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -1,6 +1,7 @@
 using LLama.Exceptions;
 using LLama.Native;
 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -384,6 +385,7 @@ namespace LLama
             return candidates_p;
         }
 
+        #region eval overloads
         /// <summary>
         /// 
         /// </summary>
@@ -391,7 +393,61 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
+        public int Eval(llama_token[] tokens, llama_token pastTokensCount)
+        {
+            return Eval(tokens.AsSpan(), pastTokensCount);
+        }
+
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="pastTokensCount"></param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
+        {
+#if NET5_0_OR_GREATER
+            var span = CollectionsMarshal.AsSpan(tokens);
+            return Eval(span, pastTokensCount);
+#else
+            // on netstandard2.0 we can't use collections marshal to get directly at the internal memory of
+            // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
+            // avoid the copying.
+
+            var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
+            try
+            {
+                tokens.CopyTo(rented, 0);
+                return Eval(rented, pastTokensCount);
+            }
+            finally
+            {
+                ArrayPool<llama_token>.Shared.Return(rented);
+            }
+#endif
+        }
+
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="pastTokensCount"></param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount)
+        {
+            return Eval(tokens.Span, pastTokensCount);
+        }
+
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="pastTokensCount"></param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
         {
             int total = tokens.Length;
             for(int i = 0; i < total; i += Params.BatchSize)
@@ -402,7 +458,7 @@ namespace LLama
                     n_eval = Params.BatchSize;
                 }
 
-                if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
+                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
                 {
                     _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                     throw new RuntimeError("Failed to eval.");
@@ -412,6 +468,7 @@ namespace LLama
             }
             return pastTokensCount;
         }
+#endregion
 
         internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
         {
@@ -419,6 +476,16 @@ namespace LLama
                 yield return _ctx.TokenToString(id, _encoding);
         }
 
+        /// <summary>
+        /// Convert a token into a string
+        /// </summary>
+        /// <param name="token"></param>
+        /// <returns></returns>
+        public string TokenToString(llama_token token)
+        {
+            return NativeHandle.TokenToString(token, Encoding);
+        }
+
        /// <inheritdoc />
        public virtual void Dispose()
        {
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 04663d77..bb5911fe 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -179,12 +179,14 @@ namespace LLama.Native
         /// the number of tokens to use from previous eval calls
         /// </param>
         /// <returns>Returns true on success</returns>
-        public bool Eval(Memory<int> tokens, int n_past, int n_threads)
+        public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
         {
-            using var pin = tokens.Pin();
             unsafe
             {
-                return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
+                fixed (int* pinned = tokens)
+                {
+                    return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
+                }
             }
         }
     }
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 27eab2c6..45acad76 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -37,7 +37,7 @@ namespace LLama
         [Obsolete("Use SafeLLamaContextHandle Eval method instead")]
         public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
         {
-            var slice = tokens.AsMemory().Slice(startIndex, n_tokens);
+            var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
             return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
         }
 