@@ -2,6 +2,9 @@
 namespace LLama.Abstractions
 {
+    /// <summary>
+    /// The parameters for initializing a LLama model.
+    /// </summary>
     public interface IModelParams
     {
         /// <summary>
@@ -5,6 +5,7 @@ using System.IO;
 using System.Runtime.CompilerServices;
 using System.Text;
 using System.Threading;
+using System.Threading.Tasks;

 namespace LLama
 {
@@ -13,10 +14,12 @@ namespace LLama
     /// </summary>
     public class ChatSession
     {
-        private ILLamaExecutor _executor;
-        private ChatHistory _history;
-        private static readonly string _executorStateFilename = "ExecutorState.json";
-        private static readonly string _modelStateFilename = "ModelState.st";
+        private readonly ILLamaExecutor _executor;
+        private readonly ChatHistory _history;
+        private const string _executorStateFilename = "ExecutorState.json";
+        private const string _modelStateFilename = "ModelState.st";

         /// <summary>
         /// The executor for this session.
         /// </summary>
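For context on the field changes above: `readonly` guarantees a field is assigned only during construction, while `const` turns the string into a compile-time constant that the compiler inlines at every use site. A minimal sketch of the distinction (illustrative class, not part of this diff):

```csharp
using System.Collections.Generic;

public class FieldKinds
{
    // readonly: assigned at the declaration or in a constructor, and the
    // reference can never be repointed afterwards.
    private readonly List<int> _items = new List<int>();

    // const: a true compile-time constant, inlined wherever it is used.
    // Only valid for strings, primitives, and a few other types.
    private const string StateFilename = "ExecutorState.json";
}
```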
@@ -227,7 +230,7 @@ namespace LLama
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
-            await foreach (var item in OutputTransform.TransformAsync(results))
+            await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
             {
                 yield return item;
             }
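`WithCancellation` is the standard way to flow a token into an async stream: it hands the token to the enumerator, so a producer declaring `[EnumeratorCancellation]` observes it even when the stream was created without one. A self-contained sketch; `GenerateAsync` is a hypothetical producer, not part of this diff:

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class AsyncStreamExample
{
    // Hypothetical producer. [EnumeratorCancellation] lets WithCancellation
    // deliver its token to this parameter at enumeration time.
    public static async IAsyncEnumerable<string> GenerateAsync(
        [EnumeratorCancellation] CancellationToken token = default)
    {
        for (var i = 0; i < 5; i++)
        {
            await Task.Delay(100, token);
            yield return $"chunk {i}";
        }
    }

    public static async Task ConsumeAsync(CancellationToken token)
    {
        // Cancelling `token` now stops both this loop and the producer.
        await foreach (var item in GenerateAsync().WithCancellation(token))
            Console.WriteLine(item);
    }
}
```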
@@ -15,9 +15,20 @@ namespace LLama.Common
         private readonly int _maxSize;
         private readonly List<T> _storage;

+        /// <summary>
+        /// Number of items in this queue
+        /// </summary>
         public int Count => _storage.Count;
+
+        /// <summary>
+        /// Maximum number of items allowed in this queue
+        /// </summary>
         public int Capacity => _maxSize;

+        /// <summary>
+        /// Create a new queue
+        /// </summary>
+        /// <param name="size">the maximum number of items to store in this queue</param>
        public FixedSizeQueue(int size)
        {
            _maxSize = size;
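For readers without the full file: the queue being documented holds at most `Capacity` items. The eviction behaviour is not visible in this hunk, so the `Enqueue` below is an assumption; the rest mirrors the members shown above.

```csharp
using System.Collections.Generic;

public sealed class BoundedQueue<T>   // illustrative stand-in for FixedSizeQueue<T>
{
    private readonly int _maxSize;
    private readonly List<T> _storage = new List<T>();

    public int Count => _storage.Count;
    public int Capacity => _maxSize;

    public BoundedQueue(int size) => _maxSize = size;

    public void Enqueue(T item)
    {
        _storage.Add(item);
        if (_storage.Count > _maxSize)
            _storage.RemoveAt(0); // assumed policy: drop the oldest item
    }
}
```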
@@ -2,14 +2,16 @@

 namespace LLama.Exceptions
 {
-    public class RuntimeError: Exception
+    public class RuntimeError
+        : Exception
     {
         public RuntimeError()
         {

         }

-        public RuntimeError(string message): base(message)
+        public RuntimeError(string message)
+            : base(message)
         {

         }
@@ -1,7 +1,6 @@
 using LLama.Exceptions;
 using LLama.Native;
 using System;
-using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -421,7 +420,7 @@ namespace LLama
             // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
             // avoid the copying.
-            var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
+            var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count);
             try
             {
                 tokens.CopyTo(rented, 0);
@@ -429,7 +428,7 @@ namespace LLama
             }
             finally
             {
-                ArrayPool<llama_token>.Shared.Return(rented);
+                System.Buffers.ArrayPool<llama_token>.Shared.Return(rented);
             }
 #endif
         }
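Both hunks above are mechanical: after `using System.Buffers;` was dropped from the imports, the `ArrayPool` references are fully qualified instead. The rent/copy/return pattern itself is unchanged; in isolation it looks like this (a generic sketch, with `int` standing in for `llama_token`):

```csharp
using System;
using System.Buffers;
using System.Collections.Generic;

public static class PoolExample
{
    public static void ProcessTokens(List<int> tokens)
    {
        // Rent a buffer at least tokens.Count long; it may be longer.
        var rented = ArrayPool<int>.Shared.Rent(tokens.Count);
        try
        {
            tokens.CopyTo(rented, 0);
            // ... operate on rented[0..tokens.Count) ...
        }
        finally
        {
            // Return in finally so the buffer goes back to the pool
            // even if the work above throws.
            ArrayPool<int>.Shared.Return(rented);
        }
    }
}
```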
@@ -84,16 +84,16 @@ namespace LLama
         /// <inheritdoc />
         public override void SaveState(string filename)
         {
-            InstructExecutorState state = (InstructExecutorState)GetStateData();
-            using (FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
+            var state = (InstructExecutorState)GetStateData();
+            using (var fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
             {
-                JsonSerializer.Serialize<InstructExecutorState>(fs, state);
+                JsonSerializer.Serialize(fs, state);
             }
         }

         /// <inheritdoc />
         public override void LoadState(string filename)
         {
-            using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
+            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
                 var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
                 LoadState(state);
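The save/load round-trip above, reduced to a sketch. `MyState` is an illustrative stand-in for `InstructExecutorState`; the `FileMode` choices mirror the diff:

```csharp
using System.IO;
using System.Text.Json;

public record MyState(int TokenCount, string Prompt);

public static class StatePersistence
{
    public static void Save(string filename, MyState state)
    {
        using var fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write);
        // The generic argument is inferred from `state`, as in the diff.
        JsonSerializer.Serialize(fs, state);
    }

    public static MyState? Load(string filename)
    {
        using var fs = new FileStream(filename, FileMode.Open, FileAccess.Read);
        return JsonSerializer.Deserialize<MyState>(fs);
    }
}
```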
@@ -3,8 +3,16 @@ using System.Runtime.InteropServices;

 namespace LLama.Native
 {
+    /// <summary>
+    /// Called by llama.cpp with a progress value between 0 and 1
+    /// </summary>
+    /// <param name="progress"></param>
+    /// <param name="ctx"></param>
     public delegate void LlamaProgressCallback(float progress, IntPtr ctx);

+    /// <summary>
+    /// A C# representation of the llama.cpp `llama_context_params` struct
+    /// </summary>
     [StructLayout(LayoutKind.Sequential)]
     public struct LLamaContextParams
     {
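A hedged sketch of how a native progress callback like the one documented above is typically wired up from managed code. The `progress_callback` field name is taken from the `llama_context_params` layout this struct mirrors; treat the wiring as illustrative rather than the library's actual code:

```csharp
using System;
using System.Runtime.InteropServices;

public static class ProgressExample
{
    public delegate void LlamaProgressCallback(float progress, IntPtr ctx);

    private static void OnProgress(float progress, IntPtr ctx)
        => Console.WriteLine($"load progress: {progress:P0}");

    public static IntPtr MakeCallbackPointer(out LlamaProgressCallback keepAlive)
    {
        // Keep a reference to the delegate for as long as native code may
        // invoke it; otherwise the GC can collect it out from under the pointer.
        keepAlive = OnProgress;
        return Marshal.GetFunctionPointerForDelegate(keepAlive);
    }
}
```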
@@ -48,7 +56,6 @@ namespace LLama.Native
         /// </summary>
         public nint tensor_split;

-
         /// <summary>
         /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
         /// RoPE base frequency
@@ -71,7 +78,6 @@ namespace LLama.Native
         /// </summary>
         public IntPtr progress_callback_user_data;

-
         /// <summary>
         /// if true, reduce VRAM usage at the cost of performance
         /// </summary>
@@ -1,5 +1,4 @@
 using System;
-using System.Diagnostics;
 using System.Text;
 using LLama.Exceptions;
@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.IO;

 #pragma warning disable
+// ReSharper disable all

 namespace LLama.OldVersion
 {
@@ -2,6 +2,7 @@
 using System.Collections.Generic;

 #pragma warning disable
+// ReSharper disable all

 namespace LLama.OldVersion
 {
@@ -3,6 +3,7 @@ using System;
 using LLama.Exceptions;

 #pragma warning disable
+// ReSharper disable all

 namespace LLama.OldVersion
 {
@@ -10,6 +10,7 @@ using System.Text;
 using LLama.Common;

 #pragma warning disable
+// ReSharper disable all

 namespace LLama.OldVersion
 {
@@ -32,7 +33,6 @@ namespace LLama.OldVersion
         bool _is_interacting;
         bool _is_antiprompt;
         bool _input_echo;
-        bool _verbose;

         // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
         // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
@@ -45,17 +45,8 @@ namespace LLama.OldVersion
         List<llama_token> _embed;

         public string Name { get; set; }
-        public bool Verbose
-        {
-            get
-            {
-                return _verbose;
-            }
-            set
-            {
-                _verbose = value;
-            }
-        }
+        public bool Verbose { get; set; }
         public SafeLLamaContextHandle NativeHandle => _ctx;

         /// <summary>
@@ -178,7 +169,7 @@ namespace LLama.OldVersion
         {
             Name = name;
             _params = @params;
-            _verbose = verbose;
+            Verbose = verbose;
             _ctx = Utils.llama_init_from_gpt_params(ref _params);

             // Add a space in front of the first character to match OG llama tokenizer behavior
@@ -514,7 +505,7 @@ namespace LLama.OldVersion
             }

             if (_is_interacting)
             {
-                if (_verbose)
+                if (Verbose)
                 {
                     LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it.");
                 }
@@ -625,7 +616,7 @@ namespace LLama.OldVersion
                 NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count);
             }

-            llama_token id = 0;
+            llama_token id;

             {
                 var n_vocab = NativeApi.llama_n_vocab(_ctx);
@@ -2,6 +2,7 @@
 using System.Collections.Generic;

 #pragma warning disable
+// ReSharper disable all

 namespace LLama.OldVersion
 {
@@ -62,7 +63,7 @@ namespace LLama.OldVersion
         public LLamaParams(int seed = 0, int n_threads = -1, int n_predict = -1,
             int n_ctx = 512, int n_batch = 512, int n_keep = 0, int n_gpu_layers = -1,
-            Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
+            Dictionary<llama_token, float>? logit_bias = null, int top_k = 40, float top_p = 0.95f,
             float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
             int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
             int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f,
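The added `?` is required under nullable reference types because the parameter defaults to `null`. Consuming such a parameter safely looks like this (names illustrative):

```csharp
using System.Collections.Generic;

public static class BiasExample
{
    public static float GetBias(Dictionary<int, float>? logitBias, int token)
    {
        // Null-conditional access plus a fallback covers the "no bias" default.
        return logitBias?.GetValueOrDefault(token) ?? 0f;
    }
}
```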
@@ -2,6 +2,7 @@
 using System.Collections.Generic;

 #pragma warning disable
+// ReSharper disable all

 namespace LLama.OldVersion
 {
@@ -8,6 +8,7 @@ using System.Runtime.InteropServices;
 using System.IO;

 #pragma warning disable
+// ReSharper disable all

 namespace LLama.OldVersion
 {
@@ -56,7 +57,7 @@ namespace LLama.OldVersion
             return res.Take(n).ToList();
         }

-        public unsafe static Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
+        public static unsafe Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
         {
             var logits = NativeApi.llama_get_logits(ctx);
             return new Span<float>(logits, length);
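`public static unsafe` is the conventional modifier order; behaviour is identical. What the method does, as a standalone sketch: wrap unmanaged memory in a `Span<float>` without copying.

```csharp
using System;

public static class LogitsExample
{
    // The span aliases the native buffer: no copy is made, and the span is
    // only valid while that buffer stays alive.
    public static unsafe Span<float> WrapLogits(float* logits, int length)
        => new Span<float>(logits, length);
}
```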
@@ -67,21 +68,24 @@ namespace LLama.OldVersion
 #if NET6_0_OR_GREATER
             return Marshal.PtrToStringUTF8(ptr);
 #else
-            byte* tp = (byte*)ptr.ToPointer();
-            List<byte> bytes = new();
-            while (true)
+            unsafe
             {
-                byte c = *tp++;
-                if (c == '\0')
+                byte* tp = (byte*)ptr.ToPointer();
+                List<byte> bytes = new();
+                while (true)
                 {
-                    break;
-                }
-                else
-                {
-                    bytes.Add(c);
+                    byte c = *tp++;
+                    if (c == '\0')
+                    {
+                        break;
+                    }
+                    else
+                    {
+                        bytes.Add(c);
+                    }
                 }
+                return Encoding.UTF8.GetString(bytes.ToArray());
             }
-            return Encoding.UTF8.GetString(bytes.ToArray());
 #endif
         }
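The restructuring above only introduces an `unsafe` block around the pre-.NET 6 fallback; the algorithm is unchanged. An equivalent standalone helper, assuming a NUL-terminated UTF-8 buffer, that also skips the `List<byte>` intermediate:

```csharp
using System;
using System.Text;

public static class Utf8Example
{
    public static unsafe string FromPtr(IntPtr ptr)
    {
        if (ptr == IntPtr.Zero)
            return string.Empty;
#if NET6_0_OR_GREATER
        return System.Runtime.InteropServices.Marshal.PtrToStringUTF8(ptr) ?? string.Empty;
#else
        byte* p = (byte*)ptr.ToPointer();
        var length = 0;
        while (p[length] != 0)   // scan to the NUL terminator
            length++;
        return Encoding.UTF8.GetString(p, length);
#endif
    }
}
```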
@@ -10,9 +10,15 @@ namespace LLama
 {
     using llama_token = Int32;

+    /// <summary>
+    /// Assorted llama utilities
+    /// </summary>
     public static class Utils
     {
+        [Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")]
+#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
         public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
+#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
         {
             using var weights = LLamaWeights.LoadFromFile(@params);
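The `[Obsolete]` message points callers at the two-step replacement. A sketch of that path; the model filename and the `ModelParams` constructor signature are assumptions, not shown in this diff:

```csharp
using LLama;
using LLama.Common;

// Hypothetical usage of the replacement API named in the attribute.
var parameters = new ModelParams("model.bin");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
```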