@@ -2,6 +2,9 @@
namespace LLama.Abstractions
{
/// <summary>
/// The parameters for initializing a LLama model.
/// </summary>
public interface IModelParams
{
/// <summary>
@@ -5,6 +5,7 @@ using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
@@ -13,10 +14,12 @@ namespace LLama
/// </summary>
public class ChatSession
{
private ILLamaExecutor _executor;
private ChatHistory _history;
private static readonly string _executorStateFilename = "ExecutorState.json";
private static readonly string _modelStateFilename = "ModelState.st";
private readonly ILLamaExecutor _executor;
private readonly ChatHistory _history;
private const string _executorStateFilename = "ExecutorState.json";
private const string _modelStateFilename = "ModelState.st";
/// <summary>
/// The executor for this session.
/// </summary>
@@ -227,7 +230,7 @@ namespace LLama
private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
await foreach (var item in OutputTransform.TransformAsync(results))
await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
{
yield return item;
}
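
Note on the `WithCancellation` change above: without it, the token passed to `ChatAsyncInternal` never reaches the enumerator produced by the output transform. A minimal, self-contained sketch of the pattern, where the `GenerateAsync` producer is hypothetical and stands in for the transformed inference stream:

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

class CancellationSketch
{
    // Hypothetical producer standing in for the transformed inference stream.
    // [EnumeratorCancellation] lets WithCancellation hand its token to this parameter.
    static async IAsyncEnumerable<string> GenerateAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        for (var i = 0; i < 10; i++)
        {
            await Task.Delay(100, cancellationToken);
            yield return $"token {i}";
        }
    }

    static async Task ConsumeAsync(CancellationToken cancellationToken)
    {
        // Without WithCancellation the token given to this method never reaches
        // the iterator; with it, cancelling the token stops the loop.
        await foreach (var item in GenerateAsync().WithCancellation(cancellationToken))
        {
            Console.WriteLine(item);
        }
    }
}
```
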
@@ -15,9 +15,20 @@ namespace LLama.Common
private readonly int _maxSize;
private readonly List<T> _storage;
/// <summary>
/// Number of items in this queue
/// </summary>
public int Count => _storage.Count;
/// <summary>
/// Maximum number of items allowed in this queue
/// </summary>
public int Capacity => _maxSize;
/// <summary>
/// Create a new queue
/// </summary>
/// <param name="size">the maximum number of items to store in this queue</param>
public FixedSizeQueue(int size)
{
_maxSize = size;
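
The new `Count` and `Capacity` properties expose the queue's fill level without reaching into the backing list. A brief usage sketch, using only the members shown in this hunk and assuming the rest of the type behaves as its names suggest:

```csharp
var queue = new FixedSizeQueue<int>(8);

// Capacity is fixed at construction; Count reflects how many items are
// currently stored and never exceeds Capacity.
Console.WriteLine($"{queue.Count}/{queue.Capacity} slots used"); // "0/8 slots used"
```
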
@@ -2,14 +2,16 @@
namespace LLama.Exceptions
{
public class RuntimeError: Exception
public class RuntimeError
: Exception
{
public RuntimeError()
{
}
public RuntimeError(string message): base(message)
public RuntimeError(string message)
: base(message)
{
}
@@ -1,7 +1,6 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -421,7 +420,7 @@ namespace LLama
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.
var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
@@ -429,7 +428,7 @@ namespace LLama
}
finally
{
ArrayPool<llama_token>.Shared.Return(rented);
System.Buffers.ArrayPool<llama_token>.Shared.Return(rented);
}
#endif
}
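
The comment in this hunk describes renting a buffer from the shared `ArrayPool` rather than allocating a fresh array for the copy. A self-contained sketch of that rent/copy/return pattern (`llama_token` is an `int` alias in the library, so plain `int` is used; the summing loop is only a stand-in for the native call the real code makes):

```csharp
using System.Buffers;
using System.Collections.Generic;

class PoolSketch
{
    // Rent a pooled array at least as large as the list, copy into it,
    // use it, then return it in a finally block so it always goes back.
    static long SumTokens(List<int> tokens)
    {
        var rented = ArrayPool<int>.Shared.Rent(tokens.Count);
        try
        {
            tokens.CopyTo(rented, 0);

            long sum = 0;
            // The rented array may be longer than requested, so only the
            // first tokens.Count entries are meaningful.
            for (var i = 0; i < tokens.Count; i++)
                sum += rented[i];
            return sum;
        }
        finally
        {
            ArrayPool<int>.Shared.Return(rented);
        }
    }
}
```
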
@@ -84,16 +84,16 @@ namespace LLama
/// <inheritdoc />
public override void SaveState(string filename)
{
InstructExecutorState state = (InstructExecutorState)GetStateData();
using (FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
var state = (InstructExecutorState)GetStateData();
using (var fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
{
JsonSerializer.Serialize<InstructExecutorState>(fs, state);
JsonSerializer.Serialize(fs, state);
}
}
/// <inheritdoc />
public override void LoadState(string filename)
{
using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
LoadState(state);
@@ -3,8 +3,16 @@ using System.Runtime.InteropServices;
namespace LLama.Native
{
/// <summary>
/// Called by llama.cpp with a progress value between 0 and 1
/// </summary>
/// <param name="progress"></param>
/// <param name="ctx"></param>
public delegate void LlamaProgressCallback(float progress, IntPtr ctx);
/// <summary>
/// A C# representation of the llama.cpp `llama_context_params` struct
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaContextParams
{
@@ -48,7 +56,6 @@ namespace LLama.Native
/// </summary>
public nint tensor_split;
/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE base frequency
@@ -71,7 +78,6 @@ namespace LLama.Native
/// </summary>
public IntPtr progress_callback_user_data;
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
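
The newly documented `LlamaProgressCallback` delegate is invoked by llama.cpp while a model loads. A minimal handler might look like the sketch below; how the delegate is marshalled into `LLamaContextParams` (alongside the `progress_callback_user_data` field shown above) is assumed rather than taken from this diff:

```csharp
using System;
using LLama.Native;

class ProgressSketch
{
    // Handler matching the LlamaProgressCallback signature documented above;
    // llama.cpp reports a value between 0 and 1 while the model loads.
    // Keeping the delegate in a field prevents it from being garbage-collected
    // while native code may still call it.
    static readonly LlamaProgressCallback Callback =
        (progress, ctx) => Console.WriteLine($"Loading model: {progress * 100:F0}%");
}
```
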
@@ -1,5 +1,4 @@
using System;
using System.Diagnostics;
using System.Text;
using LLama.Exceptions;
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
@@ -2,6 +2,7 @@
using System.Collections.Generic;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
@@ -3,6 +3,7 @@ using System;
using LLama.Exceptions;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
@@ -10,6 +10,7 @@ using System.Text;
using LLama.Common;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
@@ -32,7 +33,6 @@ namespace LLama.OldVersion
bool _is_interacting;
bool _is_antiprompt;
bool _input_echo;
bool _verbose;
// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
@@ -45,17 +45,8 @@ namespace LLama.OldVersion
List<llama_token> _embed;
public string Name { get; set; }
public bool Verbose
{
get
{
return _verbose;
}
set
{
_verbose = value;
}
}
public bool Verbose { get; set; }
public SafeLLamaContextHandle NativeHandle => _ctx;
/// <summary>
@@ -178,7 +169,7 @@ namespace LLama.OldVersion
{
Name = name;
_params = @params;
_verbose = verbose;
Verbose = verbose;
_ctx = Utils.llama_init_from_gpt_params(ref _params);
// Add a space in front of the first character to match OG llama tokenizer behavior
@@ -514,7 +505,7 @@ namespace LLama.OldVersion
}
if (_is_interacting)
{
if (_verbose)
if (Verbose)
{
LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it.");
}
@@ -625,7 +616,7 @@ namespace LLama.OldVersion
NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count);
}
llama_token id = 0;
llama_token id;
{
var n_vocab = NativeApi.llama_n_vocab(_ctx);
@@ -2,6 +2,7 @@
using System.Collections.Generic;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
@@ -62,7 +63,7 @@ namespace LLama.OldVersion
public LLamaParams(int seed = 0, int n_threads = -1, int n_predict = -1,
int n_ctx = 512, int n_batch = 512, int n_keep = 0, int n_gpu_layers = -1,
Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
Dictionary<llama_token, float>? logit_bias = null, int top_k = 40, float top_p = 0.95f,
float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f,
@@ -2,6 +2,7 @@
using System.Collections.Generic;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
@@ -8,6 +8,7 @@ using System.Runtime.InteropServices;
using System.IO;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
@@ -56,7 +57,7 @@ namespace LLama.OldVersion
return res.Take(n).ToList();
}
public unsafe static Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
public static unsafe Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
{
var logits = NativeApi.llama_get_logits(ctx);
return new Span<float>(logits, length);
@@ -67,21 +68,24 @@ namespace LLama.OldVersion
#if NET6_0_OR_GREATER
return Marshal.PtrToStringUTF8(ptr);
#else
byte* tp = (byte*)ptr.ToPointer();
List<byte> bytes = new();
while (true)
unsafe
{
byte c = *tp++;
if (c == '\0')
byte* tp = (byte*)ptr.ToPointer();
List<byte> bytes = new();
while (true)
{
break;
}
else
{
bytes.Add(c);
byte c = *tp++;
if (c == '\0')
{
break;
}
else
{
bytes.Add(c);
}
}
return Encoding.UTF8.GetString(bytes.ToArray());
}
return Encoding.UTF8.GetString(bytes.ToArray());
#endif
}
@@ -10,9 +10,15 @@ namespace LLama
{
using llama_token = Int32;
/// <summary>
/// Assorted llama utilities
/// </summary>
public static class Utils
{
[Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
using var weights = LLamaWeights.LoadFromFile(@params);
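
For reference, the non-obsolete path named in the `[Obsolete]` message would look roughly like the sketch below. Only `LLamaWeights.LoadFromFile(@params)` appears verbatim in the hunk; passing the same `IModelParams` to `CreateContext` is an assumption, not taken from this diff:

```csharp
// Rough sketch of the replacement suggested by the [Obsolete] attribute.
// The CreateContext parameter type is assumed, not shown in this diff.
using var weights = LLamaWeights.LoadFromFile(@params);
using var context = weights.CreateContext(@params);
```
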