- Made `NativeApi` into a `static class` (it's not intended to be instantiated)
- Moved the `LLamaTokenType` enum out into a separate file
- Made `LLamaSeqId` and `LLamaPos` into `record struct`s, convenient to have equality etc.
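To illustrate the `record struct` point, here is a minimal sketch with a hypothetical `Pos` stand-in (not the library's type): the compiler synthesizes value-based `Equals`, `GetHashCode`, and `==`/`!=`, so positions and sequence IDs compare by value and work as set or dictionary keys without hand-written boilerplate.

```csharp
// Hypothetical stand-in for LLamaPos/LLamaSeqId, showing what `record struct` buys.
using System;
using System.Collections.Generic;

public record struct Pos(int Value);

public static class RecordStructDemo
{
    public static void Main()
    {
        var a = new Pos(5);
        var b = new Pos(5);

        Console.WriteLine(a == b);           // True: synthesized value equality
        var seen = new HashSet<Pos> { a };
        Console.WriteLine(seen.Contains(b)); // True: synthesized GetHashCode too
    }
}
```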
@@ -214,7 +214,7 @@ namespace LLama.Abstractions
/// <summary>
/// Get the key being overridden by this override
/// </summary>
public string Key { get; init; }
public string Key { get; }
internal LLamaModelKvOverrideType Type { get; }
@@ -1,5 +1,4 @@
using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -37,6 +36,7 @@ namespace LLama.Common
/// </summary>
public class ChatHistory
{
private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };
/// <summary>
/// Chat message representation
@@ -96,12 +96,7 @@ namespace LLama.Common
/// <returns></returns>
public string ToJson()
{
return JsonSerializer.Serialize(
this,
new JsonSerializerOptions()
{
WriteIndented = true
});
return JsonSerializer.Serialize(this, _jsonOptions);
}
/// <summary>
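The `ToJson` change above is the standard System.Text.Json caching pattern: serialization metadata is cached per `JsonSerializerOptions` instance, so allocating fresh options on every call discards that cache. A minimal sketch of the same pattern, with a hypothetical `Message` type:

```csharp
using System.Text.Json;

public class Message
{
    public string? Author { get; set; }
    public string? Text { get; set; }
}

public static class JsonOptionsDemo
{
    // One shared instance, mirroring the _jsonOptions field added above.
    private static readonly JsonSerializerOptions Options = new() { WriteIndented = true };

    public static string ToJson(Message m) => JsonSerializer.Serialize(m, Options);
}
```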
@@ -2,7 +2,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using LLama.Extensions;
namespace LLama.Common
{
@@ -18,11 +18,13 @@ namespace LLama.Common
/// number of tokens to keep from initial prompt
/// </summary>
public int TokensKeep { get; set; } = 0;
/// <summary>
/// how many new tokens to predict (n_predict); set to -1 to generate
/// infinitely until the response completes.
/// </summary>
public int MaxTokens { get; set; } = -1;
/// <summary>
/// logit bias for specific tokens
/// </summary>
@@ -15,6 +15,7 @@ namespace LLama.Extensions
internal static TValue GetValueOrDefaultImpl<TKey, TValue>(IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
{
// ReSharper disable once CanSimplifyDictionaryTryGetValueWithGetValueOrDefault (this is a shim for that method!)
return dictionary.TryGetValue(key, out var value) ? value : defaultValue;
}
}
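A quick usage sketch for the shim above, which backfills `GetValueOrDefault` on targets where `IReadOnlyDictionary` lacks it (the demo class is hypothetical):

```csharp
using System;
using System.Collections.Generic;

public static class ShimDemo
{
    internal static TValue GetValueOrDefaultImpl<TKey, TValue>(
        IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
        => dictionary.TryGetValue(key, out var value) ? value : defaultValue;

    public static void Main()
    {
        var map = new Dictionary<string, int> { ["a"] = 1 };
        Console.WriteLine(GetValueOrDefaultImpl(map, "a", 0));  // 1
        Console.WriteLine(GetValueOrDefaultImpl(map, "b", 42)); // 42, no KeyNotFoundException
    }
}
```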
@@ -15,7 +15,7 @@ namespace LLama.Grammars
/// <summary>
/// Index of the initial rule to start from
/// </summary>
public ulong StartRuleIndex { get; set; }
public ulong StartRuleIndex { get; }
/// <summary>
/// The rules which make up this grammar
@@ -121,6 +121,12 @@ namespace LLama.Grammars
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
break;
case LLamaGrammarElementType.END:
case LLamaGrammarElementType.ALT:
case LLamaGrammarElementType.RULE_REF:
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
default:
output.Append("] ");
break;
@@ -43,7 +43,7 @@ namespace LLama
/// <summary>
/// The context params set for this context
/// </summary>
public IContextParams Params { get; set; }
public IContextParams Params { get; }
/// <summary>
/// The native handle, which is used to be passed to the native APIs
@@ -56,15 +56,6 @@ namespace LLama
/// </summary>
public Encoding Encoding { get; }
internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;
_logger = logger;
Encoding = @params.Encoding;
NativeHandle = nativeContext;
}
/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
@@ -12,17 +12,15 @@ namespace LLama
public sealed class LLamaEmbedder
: IDisposable
{
private readonly LLamaContext _ctx;
/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;
public int EmbeddingSize => Context.EmbeddingSize;
/// <summary>
/// LLama Context
/// </summary>
public LLamaContext Context => this._ctx;
public LLamaContext Context { get; }
/// <summary>
/// Create a new embedder, using the given LLamaWeights
@@ -33,7 +31,7 @@ namespace LLama
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
@params.EmbeddingMode = true;
_ctx = weights.CreateContext(@params, logger);
Context = weights.CreateContext(@params, logger);
}
/// <summary>
@@ -72,20 +70,20 @@ namespace LLama
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos)
{
var embed_inp_array = _ctx.Tokenize(text, addBos);
var embed_inp_array = Context.Tokenize(text, addBos);
// TODO(Rinne): deal with log of prompt
if (embed_inp_array.Length > 0)
_ctx.Eval(embed_inp_array, 0);
Context.Eval(embed_inp_array, 0);
unsafe
{
var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();
return new Span<float>(embeddings, EmbeddingSize).ToArray();
return embeddings.ToArray();
}
}
@@ -94,7 +92,7 @@ namespace LLama
/// </summary>
public void Dispose()
{
_ctx.Dispose();
Context.Dispose();
}
}
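For orientation, a hedged usage sketch of the refactored embedder (the model path is hypothetical; `ModelParams` and `LLamaWeights.LoadFromFile` are the library's usual entry points at this version):

```csharp
using System;
using LLama;
using LLama.Common;

class EmbedderDemo
{
    static void Main()
    {
        // Hypothetical model path - substitute a real GGUF file.
        var @params = new ModelParams("models/model.gguf");
        using var weights = LLamaWeights.LoadFromFile(@params);
        using var embedder = new LLamaEmbedder(weights, @params);

        var embedding = embedder.GetEmbeddings("Hello, world", addBos: true);
        Console.WriteLine($"dim = {embedding.Length}"); // equals embedder.EmbeddingSize
    }
}
```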
@@ -64,7 +64,7 @@ namespace LLama
/// </summary>
public IReadOnlyDictionary<string, string> Metadata { get; set; }
internal LLamaWeights(SafeLlamaModelHandle weights)
private LLamaWeights(SafeLlamaModelHandle weights)
{
NativeHandle = weights;
Metadata = weights.ReadMetadata();
@@ -14,7 +14,7 @@ public struct LLamaKvCacheViewCell
/// May be negative if the cell is not populated.
/// </summary>
public LLamaPos pos;
};
}
/// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view)
@@ -130,7 +130,7 @@ public class LLamaKvCacheViewSafeHandle
}
}
partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Create an empty KV cache view. (use only for debugging purposes)
@@ -6,7 +6,7 @@ namespace LLama.Native;
/// Indicates position in a sequence
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaPos
public record struct LLamaPos
{
/// <summary>
/// The raw value
@@ -17,7 +17,7 @@ public struct LLamaPos
/// Create a new LLamaPos
/// </summary>
/// <param name="value"></param>
public LLamaPos(int value)
private LLamaPos(int value)
{
Value = value;
}
@@ -6,7 +6,7 @@ namespace LLama.Native;
/// ID for a sequence in a batch
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaSeqId
public record struct LLamaSeqId
{
/// <summary>
/// The raw value
@@ -17,7 +17,7 @@ public struct LLamaSeqId
/// Create a new LLamaSeqId
/// </summary>
/// <param name="value"></param>
public LLamaSeqId(int value)
private LLamaSeqId(int value)
{
Value = value;
}
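Making these constructors private works because callers go through conversions instead, presumably via the implicit `int` conversion operators these types carry. The sketch below uses a hypothetical `Pos` to show the shape of that API:

```csharp
// Hypothetical Pos mirroring the LLamaPos pattern: private constructor,
// implicit conversions as the public surface.
public record struct Pos
{
    public int Value;

    private Pos(int value) => Value = value;

    public static implicit operator Pos(int value) => new(value);
    public static implicit operator int(Pos pos) => pos.Value;
}

public static class PosDemo
{
    public static void Main()
    {
        Pos p = 7;   // int -> Pos, no visible constructor call
        int raw = p; // Pos -> int
        System.Console.WriteLine(raw + 1); // 8
    }
}
```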
@@ -0,0 +1,12 @@
namespace LLama.Native;
public enum LLamaTokenType
{
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
LLAMA_TOKEN_TYPE_CONTROL = 3,
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
LLAMA_TOKEN_TYPE_UNUSED = 5,
LLAMA_TOKEN_TYPE_BYTE = 6,
}
@@ -3,7 +3,7 @@ using System.Runtime.InteropServices;
namespace LLama.Native;
public partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Type of pointer to the beam_search_callback function.
@@ -5,7 +5,7 @@ namespace LLama.Native
{
using llama_token = Int32;
public unsafe partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Create a new grammar from the given set of grammar rules
@@ -15,7 +15,7 @@ namespace LLama.Native
/// <param name="start_rule_index"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
public static extern unsafe IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
/// <summary>
/// Free all memory from the given SafeLLamaGrammarHandle
@@ -4,13 +4,12 @@ using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
namespace LLama.Native
{
public partial class NativeApi
public static partial class NativeApi
{
static NativeApi()
{
@@ -97,22 +96,13 @@ namespace LLama.Native
}
if (string.IsNullOrEmpty(version))
{
return -1;
}
else
{
version = version.Split('.')[0];
bool success = int.TryParse(version, out var majorVersion);
if (success)
{
return majorVersion;
}
else
{
return -1;
}
}
version = version.Split('.')[0];
if (int.TryParse(version, out var majorVersion))
return majorVersion;
return -1;
}
private static string GetCudaVersionFromPath(string cudaPath)
@@ -129,7 +119,7 @@ namespace LLama.Native
{
return string.Empty;
}
return versionNode.GetString();
return versionNode.GetString() ?? "";
}
}
catch (Exception)
@@ -169,18 +159,14 @@ namespace LLama.Native
{
platform = OSPlatform.OSX;
suffix = ".dylib";
if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
{
prefix = "runtimes/osx-arm64/native/";
}
else
{
prefix = "runtimes/osx-x64/native/";
}
prefix = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported
? "runtimes/osx-arm64/native/"
: "runtimes/osx-x64/native/";
}
else
{
throw new RuntimeError($"Your system platform is not supported, please open an issue in LLamaSharp.");
throw new RuntimeError("Your system platform is not supported, please open an issue in LLamaSharp.");
}
Log($"Detected OS Platform: {platform}", LogLevel.Information);
@@ -275,15 +261,15 @@ namespace LLama.Native
var libraryTryLoadOrder = GetLibraryTryOrder(configuration);
string[] preferredPaths = configuration.SearchDirectories;
string[] possiblePathPrefix = new string[] {
System.AppDomain.CurrentDomain.BaseDirectory,
var preferredPaths = configuration.SearchDirectories;
var possiblePathPrefix = new[] {
AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};
var tryFindPath = (string filename) =>
string TryFindPath(string filename)
{
foreach(var path in preferredPaths)
foreach (var path in preferredPaths)
{
if (File.Exists(Path.Combine(path, filename)))
{
@@ -291,7 +277,7 @@ namespace LLama.Native
}
}
foreach(var path in possiblePathPrefix)
foreach (var path in possiblePathPrefix)
{
if (File.Exists(Path.Combine(path, filename)))
{
@@ -300,21 +286,19 @@ namespace LLama.Native
}
return filename;
};
}
foreach (var libraryPath in libraryTryLoadOrder)
{
var fullPath = tryFindPath(libraryPath);
var fullPath = TryFindPath(libraryPath);
var result = TryLoad(fullPath, true);
if (result is not null && result != IntPtr.Zero)
{
Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information);
return result ?? IntPtr.Zero;
}
else
{
Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
return (IntPtr)result;
}
Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
}
if (!configuration.AllowFallback)
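Replacing the `tryFindPath` lambda with the `TryFindPath` local function above is more than style: a lambda stored in a `Func<,>` allocates a delegate instance, while a local function compiles to an ordinary method call and can even be invoked before its point of declaration. A toy contrast:

```csharp
using System;

public static class LocalFunctionDemo
{
    public static void Main()
    {
        // Lambda: allocates a Func<string, string> delegate instance.
        Func<string, string> viaLambda = s => $"[{s}]";

        // Local function: plain method call, no delegate object needed,
        // and callable before its point of declaration.
        Console.WriteLine(viaLambda("a"));
        Console.WriteLine(ViaLocal("b"));

        static string ViaLocal(string s) => $"[{s}]";
    }
}
```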
@@ -2,7 +2,7 @@
namespace LLama.Native
{
public partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Returns 0 on success
@@ -5,7 +5,7 @@ namespace LLama.Native
{
using llama_token = Int32;
public unsafe partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -19,7 +19,7 @@ namespace LLama.Native
/// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
/// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
ref LLamaTokenDataArrayNative candidates,
llama_token* last_tokens, ulong last_tokens_size,
float penalty_repeat,
@@ -9,17 +9,6 @@ namespace LLama.Native
{
using llama_token = Int32;
public enum LLamaTokenType
{
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
LLAMA_TOKEN_TYPE_CONTROL = 3,
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
LLAMA_TOKEN_TYPE_UNUSED = 5,
LLAMA_TOKEN_TYPE_BYTE = 6,
}
/// <summary>
/// Callback from llama.cpp with log messages
/// </summary>
@@ -30,7 +19,7 @@ namespace LLama.Native
/// <summary>
/// Direct translation of the llama.cpp API
/// </summary>
public unsafe partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
@@ -165,7 +154,7 @@ namespace LLama.Native
/// <param name="dest"></param>
/// <returns>the number of bytes copied</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);
public static extern unsafe ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);
/// <summary>
/// Set the state reading from the specified address
@@ -174,7 +163,7 @@ namespace LLama.Native
/// <param name="src"></param>
/// <returns>the number of bytes read</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);
public static extern unsafe ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);
/// <summary>
/// Load session file
@@ -186,7 +175,7 @@ namespace LLama.Native
/// <param name="n_token_count_out"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);
public static extern unsafe bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);
/// <summary>
/// Save session file
@@ -211,7 +200,7 @@ namespace LLama.Native
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[Obsolete("use llama_decode() instead")]
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);
/// <summary>
/// Convert the provided text into tokens.
@@ -228,34 +217,37 @@ namespace LLama.Native
/// </returns>
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special)
{
// Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
var byteCount = encoding.GetByteCount(text);
var array = ArrayPool<byte>.Shared.Rent(byteCount + 1);
try
unsafe
{
// Convert to bytes
fixed (char* textPtr = text)
fixed (byte* arrayPtr = array)
// Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
var byteCount = encoding.GetByteCount(text);
var array = ArrayPool<byte>.Shared.Rent(byteCount + 1);
try
{
encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
// Convert to bytes
fixed (char* textPtr = text)
fixed (byte* arrayPtr = array)
{
encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
}
// Add a zero byte to the end to terminate the string
array[byteCount] = 0;
// Do the actual tokenization
fixed (byte* arrayPtr = array)
fixed (llama_token* tokensPtr = tokens)
return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
// Add a zero byte to the end to terminate the string
array[byteCount] = 0;
// Do the actual tokenization
fixed (byte* arrayPtr = array)
fixed (llama_token* tokensPtr = tokens)
return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token);
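A hedged usage sketch for the managed `llama_tokenize` wrapper above (how the context handle is obtained is elided, and the 64-token buffer is an arbitrary size chosen for a short prompt):

```csharp
using System;
using System.Text;
using LLama.Native;

public static class TokenizeDemo
{
    // llama_token is an Int32 alias in this file, so int[] works as the buffer.
    public static int[] Tokenize(SafeLLamaContextHandle ctx, string text)
    {
        var buffer = new int[64];
        var count = NativeApi.llama_tokenize(
            ctx, text, Encoding.UTF8, buffer, buffer.Length, add_bos: true, special: false);

        if (count < 0)
            throw new InvalidOperationException($"Buffer too small; need {-count} tokens.");

        return buffer.AsSpan(0, count).ToArray();
    }
}
```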
@@ -281,7 +273,7 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits(SafeLLamaContextHandle ctx);
public static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx);
/// <summary>
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
@@ -290,16 +282,24 @@ namespace LLama.Native
/// <param name="i"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
/// <summary>
/// Get the embeddings for the input
/// shape: [n_embd] (1-dimensional)
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx);
public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
{
unsafe
{
var ptr = llama_get_embeddings_native(ctx);
return new Span<float>(ptr, ctx.EmbeddingSize);
}
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
}
/// <summary>
/// Get the "Beginning of sentence" token
@@ -426,7 +426,7 @@ namespace LLama.Native
/// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
/// <summary>
/// Get the number of metadata key/value pairs
@@ -445,7 +445,7 @@ namespace LLama.Native
/// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
public static extern unsafe int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
/// <summary>
/// Get metadata value as a string by index
@@ -456,7 +456,7 @@ namespace LLama.Native
/// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
public static extern unsafe int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
/// <summary>
/// Get a string describing the model type
@@ -466,7 +466,7 @@ namespace LLama.Native
/// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
public static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
/// <summary>
/// Get the size of the model in bytes
@@ -493,7 +493,7 @@ namespace LLama.Native
/// <param name="length">size of the buffer</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
/// <summary>
/// Convert text into tokens
@@ -509,7 +509,7 @@ namespace LLama.Native
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
/// <summary>
/// Register a callback to receive llama log messages
@@ -29,10 +29,11 @@ namespace LLama.Native
private bool _allowFallback = true;
private bool _skipCheck = false;
private bool _logging = false;
/// <summary>
/// search directory -> priority level, 0 is the lowest.
/// </summary>
private List<string> _searchDirectories = new List<string>();
private readonly List<string> _searchDirectories = new List<string>();
private static void ThrowIfLoaded()
{
@@ -159,9 +160,8 @@ namespace LLama.Native
internal static Description CheckAndGatherDescription()
{
if (Instance._allowFallback && Instance._skipCheck)
{
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
}
return new Description(
Instance._libraryPath,
Instance._useCuda,
@@ -169,7 +169,8 @@ namespace LLama.Native
Instance._allowFallback,
Instance._skipCheck,
Instance._logging,
Instance._searchDirectories.Concat(new string[] { "./" }).ToArray());
Instance._searchDirectories.Concat(new[] { "./" }).ToArray()
);
}
internal static string AvxLevelToString(AvxLevel level)
@@ -204,7 +205,9 @@ namespace LLama.Native
if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported)
return false;
// ReSharper disable UnusedVariable (ebx is used when < NET8)
var (_, ebx, ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0);
// ReSharper restore UnusedVariable
var vnni = (ecx & 0b_1000_0000_0000) != 0;
@@ -1,6 +1,5 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
@@ -51,8 +50,6 @@ namespace LLama.Native
_model.DangerousAddRef(ref success);
if (!success)
throw new RuntimeError("Failed to increment model refcount");
}
/// <inheritdoc />
@@ -214,7 +214,6 @@ namespace LLama.Native
/// Get the metadata key for the given index
/// </summary>
/// <param name="index">The index to get</param>
/// <param name="buffer">A temporary buffer to store key characters in. Must be large enough to contain the key.</param>
/// <returns>The key, null if there is no such key or if the buffer was too small</returns>
public Memory<byte>? MetadataKeyByIndex(int index)
{
@@ -243,7 +242,6 @@ namespace LLama.Native
/// Get the metadata value for the given index
/// </summary>
/// <param name="index">The index to get</param>
/// <param name="buffer">A temporary buffer to store value characters in. Must be large enough to contain the value.</param>
/// <returns>The value, null if there is no such value or if the buffer was too small</returns>
public Memory<byte>? MetadataValueByIndex(int index)
{