Tag: 0.11.0

* Updated binaries to llama.cpp `3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6` (build run: https://github.com/SciSharp/LLamaSharp/actions/runs/8118890586)
* Added an abort callback
* Added properties to get and set the thread counts on `LLamaContext`
* Fixed `LLamaLogLevel` numbering
@@ -103,5 +103,11 @@ namespace LLama.Web.Common
         /// <inheritdoc />
         public bool VocabOnly { get; set; }
+
+        /// <inheritdoc />
+        public float DefragThreshold { get; set; }
+
+        /// <inheritdoc />
+        public bool DoPooling { get; set; }
     }
 }
@@ -98,4 +98,14 @@ public interface IContextParams
     /// Whether to disable offloading the KQV cache to the GPU
     /// </summary>
     bool NoKqvOffload { get; }
+
+    /// <summary>
+    /// Defragment the KV cache if holes/size > defrag_threshold; set to < 0 to disable (default)
+    /// </summary>
+    float DefragThreshold { get; }
+
+    /// <summary>
+    /// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+    /// </summary>
+    bool DoPooling { get; }
 }
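These two options surface llama.cpp's `defrag_threshold` and `do_pooling` context parameters. A minimal sketch of setting them through `ModelParams` (the stock `IContextParams` implementation); the model path is a placeholder:

```csharp
using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf") // placeholder path
{
    ContextSize = 4096,
    // Defragment the KV cache once holes/size exceeds 10%; negative disables
    DefragThreshold = 0.1f,
    // Sum embedding results per sequence id (ignored without a pooling layer)
    DoPooling = true,
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
```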
@@ -251,7 +251,7 @@ namespace LLama.Abstractions
         {
             Key = key;
             _valueInt = value;
-            Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
+            Type = LLamaModelKvOverrideType.Int;
         }

         /// <summary>
@@ -263,7 +263,7 @@ namespace LLama.Abstractions
         {
             Key = key;
             _valueFloat = value;
-            Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
+            Type = LLamaModelKvOverrideType.Float;
         }

         /// <summary>
@@ -275,20 +275,20 @@ namespace LLama.Abstractions
         {
             Key = key;
             _valueBool = value;
-            Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
+            Type = LLamaModelKvOverrideType.Bool;
         }

         internal void WriteValue(ref LLamaModelMetadataOverride dest)
         {
             switch (Type)
             {
-                case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+                case LLamaModelKvOverrideType.Int:
                     dest.IntValue = _valueInt;
                     break;
-                case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+                case LLamaModelKvOverrideType.Float:
                     dest.FloatValue = _valueFloat;
                     break;
-                case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+                case LLamaModelKvOverrideType.Bool:
                     dest.BoolValue = _valueBool ? -1L : 0;
                     break;
                 default:
@@ -300,13 +300,13 @@ namespace LLama.Abstractions
         {
             switch (Type)
             {
-                case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+                case LLamaModelKvOverrideType.Int:
                     writer.WriteNumberValue(_valueInt);
                     break;
-                case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+                case LLamaModelKvOverrideType.Float:
                     writer.WriteNumberValue(_valueFloat);
                     break;
-                case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+                case LLamaModelKvOverrideType.Bool:
                     writer.WriteBooleanValue(_valueBool);
                     break;
                 default:
@@ -328,9 +328,9 @@ namespace LLama.Abstractions
             return ((LLamaModelKvOverrideType)ktv.Type) switch
             {
-                LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
-                LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
-                LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
+                LLamaModelKvOverrideType.Int => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
+                LLamaModelKvOverrideType.Float => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
+                LLamaModelKvOverrideType.Bool => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
                 _ => throw new JsonException(),
             };
         }
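Callers never need to spell the shortened enum names out, because the `MetadataOverride` constructor overloads select the type from the argument. A sketch, assuming the usual `ModelParams.MetadataOverrides` list; the metadata keys are hypothetical:

```csharp
using LLama.Abstractions;
using LLama.Common;

var parameters = new ModelParams("model.gguf");

// Each constructor overload stamps the matching LLamaModelKvOverrideType
// (Int, Float or Bool) onto the override
parameters.MetadataOverrides.Add(new MetadataOverride("example.int.key", 42));
parameters.MetadataOverrides.Add(new MetadataOverride("example.float.key", 0.5f));
parameters.MetadataOverrides.Add(new MetadataOverride("example.bool.key", true));
```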
@@ -262,9 +262,9 @@ public sealed class Conversation
     /// <param name="start">Start position (inclusive)</param>
     /// <param name="end">End position (exclusive)</param>
     /// <param name="delta">Amount to add on to each token position</param>
-    public void Shift(LLamaPos start, LLamaPos end, int delta)
+    public void Add(LLamaPos start, LLamaPos end, int delta)
     {
-        _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
+        _conversation.Executor.Context.NativeHandle.KvCacheSequenceAdd(_conversation.ConversationId, start, end, delta);
     }
     #endregion
@@ -50,7 +50,7 @@ public static class ConversationExtensions
         kv.Remove(keep, count);

         // Shift the C's
-        kv.Shift(keep + count, end, -count);
+        kv.Add(keep + count, end, -count);

         // Update total count
         return end.Value - count;
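`ShiftLeft` above is the template for manual cache edits: remove a run of tokens, then use the renamed `Add` to slide the remainder back. A hedged sketch against a `BatchedExecutor` conversation, assuming the same `Modify`/`KvAccessor` surface the extension uses:

```csharp
using LLama.Batched;

// Drop `count` tokens after the first `keep`, then shift the remaining
// entries left so positions stay contiguous
static void DropMiddle(Conversation conversation, int keep, int count)
{
    conversation.Modify((end, kv) =>
    {
        kv.Remove(keep, count);            // remove positions [keep, keep + count)
        kv.Add(keep + count, end, -count); // renamed from Shift(): delta is negative

        return end.Value - count;          // report the new sequence end
    });
}
```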
@@ -93,6 +93,12 @@ namespace LLama.Common
         /// <inheritdoc />
         public bool NoKqvOffload { get; set; }

+        /// <inheritdoc />
+        public float DefragThreshold { get; set; }
+
+        /// <inheritdoc />
+        public bool DoPooling { get; set; }
+
         /// <inheritdoc />
         public bool VocabOnly { get; set; }
@@ -34,7 +34,9 @@ namespace LLama.Extensions
             result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
             result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
             result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
-            result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
+            result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;
+            result.defrag_threshold = @params.DefragThreshold;

             result.cb_eval = IntPtr.Zero;
             result.cb_eval_user_data = IntPtr.Zero;
@@ -42,6 +44,7 @@ namespace LLama.Extensions
             result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
             result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
             result.offload_kqv = !@params.NoKqvOffload;
+            result.do_pooling = @params.DoPooling;

             result.n_threads = Threads(@params.Threads);
             result.n_threads_batch = Threads(@params.BatchThreads);
@@ -56,6 +56,35 @@ namespace LLama
         /// </summary>
         public Encoding Encoding { get; }

+        private uint _generationThreads;
+        private uint _batchThreads;
+
+        /// <summary>
+        /// Get or set the number of threads to use for generation
+        /// </summary>
+        public uint GenerationThreads
+        {
+            get => _generationThreads;
+            set
+            {
+                _generationThreads = value;
+                NativeHandle.SetThreads(_generationThreads, _batchThreads);
+            }
+        }
+
+        /// <summary>
+        /// Get or set the number of threads to use for batch processing
+        /// </summary>
+        public uint BatchThreads
+        {
+            get => _batchThreads;
+            set
+            {
+                _batchThreads = value;
+                NativeHandle.SetThreads(_generationThreads, _batchThreads);
+            }
+        }
+
         /// <summary>
         /// Create a new LLamaContext for the given LLamaWeights
         /// </summary>
@@ -75,6 +104,10 @@ namespace LLama
             @params.ToLlamaContextParams(out var lparams);
             NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
+
+            // It's not possible to read these values back from llama.cpp, so store a copy of them here.
+            _generationThreads = lparams.n_threads;
+            _batchThreads = lparams.n_threads_batch;
         }

         /// <summary>
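A short sketch of the new thread properties: reads return the values cached in the constructor, and writes push both counts down through `llama_set_n_threads`:

```csharp
using System;
using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf") { Threads = 8 };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

Console.WriteLine($"gen={context.GenerationThreads} batch={context.BatchThreads}");

// Retune at runtime: fewer threads for token-by-token generation,
// more for prompt batch processing
context.GenerationThreads = 4;
context.BatchThreads = 16;
```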
@@ -59,7 +59,7 @@ namespace LLama
         private static bool ValidateFtype(LLamaFtype ftype)
         {
             // Validation copied from here:
-            // https://github.com/ggerganov/llama.cpp/blob/d71ac90985854b0905e1abba778e407e17f9f887/llama.cpp#L9613
+            // https://github.com/ggerganov/llama.cpp/blob/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6/llama.cpp#L10965
             switch (ftype)
             {
@@ -74,7 +74,7 @@ namespace LLama
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K_S:
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K:
-                case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_XS:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_K_XS:
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_S:
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_M:
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_L:
@@ -89,8 +89,18 @@ namespace LLama
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XXS:
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XS:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_S:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_M:
                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_S:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_NL:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_S:
+                case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_M:
                     return true;

                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
@@ -134,7 +134,7 @@ namespace LLama
             var n_discard = n_left / 2;

             NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-            NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

             n_past -= n_discard;
         }
@@ -0,0 +1,11 @@
+namespace LLama.Native;
+
+/// <summary>
+/// A single message in a chat, passed to llama_chat_apply_template
+/// </summary>
+/// <remarks>llama_chat_message</remarks>
+public unsafe struct LLamaChatMessage
+{
+    public byte* role;
+    public byte* content;
+}
@@ -51,32 +51,42 @@ namespace LLama.Native
         /// RoPE base frequency, 0 = from model
         /// </summary>
         public float rope_freq_base;

         /// <summary>
         /// RoPE frequency scaling factor, 0 = from model
         /// </summary>
         public float rope_freq_scale;

         /// <summary>
         /// YaRN extrapolation mix factor, negative = from model
         /// </summary>
         public float yarn_ext_factor;

         /// <summary>
         /// YaRN magnitude scaling factor
         /// </summary>
         public float yarn_attn_factor;

         /// <summary>
         /// YaRN low correction dim
         /// </summary>
         public float yarn_beta_fast;

         /// <summary>
         /// YaRN high correction dim
         /// </summary>
         public float yarn_beta_slow;

         /// <summary>
         /// YaRN original context size
         /// </summary>
         public uint yarn_orig_ctx;

+        /// <summary>
+        /// Defragment the KV cache if holes/size > defrag_threshold; set to < 0 to disable (default)
+        /// </summary>
+        public float defrag_threshold;
+
         /// <summary>
         /// ggml_backend_sched_eval_callback
         /// </summary>
@@ -97,11 +107,6 @@ namespace LLama.Native
         /// </summary>
         public GGMLType type_v;

-        /// <summary>
-        /// Deprecated!
-        /// </summary>
-        private sbyte _mul_mat_q;
-
         /// <summary>
         /// Deprecated!
         /// </summary>
@@ -126,6 +131,16 @@ namespace LLama.Native
             set => _offload_kqv = Convert.ToSByte(value);
         }
         private sbyte _offload_kqv;
+
+        /// <summary>
+        /// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+        /// </summary>
+        public bool do_pooling
+        {
+            readonly get => Convert.ToBoolean(_do_pooling);
+            set => _do_pooling = Convert.ToSByte(value);
+        }
+        private sbyte _do_pooling;
     }
 }
@@ -124,13 +124,48 @@
     /// <summary>
     /// except 1d tensors
     /// </summary>
-    LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22,
+    LLAMA_FTYPE_MOSTLY_IQ3_K_XS = 22,

     /// <summary>
     /// except 1d tensors
     /// </summary>
     LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23,

+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_IQ1_S = 24,
+
+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_IQ4_NL = 25,
+
+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_IQ3_S = 26,
+
+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_IQ3_M = 27,
+
+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_IQ2_S = 28,
+
+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_IQ2_M = 29,
+
+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_IQ4_XS = 30,
+
     /// <summary>
     /// File type was not specified
     /// </summary>
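With these values whitelisted, the new IQ formats can be requested when quantizing. A hedged sketch using `LLamaQuantizer`; the file names are placeholders and the optional parameters (thread count, requantization flags) are omitted, so check the actual signature in your version:

```csharp
using LLama;
using LLama.Native;

// Re-quantize an f16 GGUF into the new IQ4_XS format; the quantizer
// rejects ftypes that fail ValidateFtype
bool ok = LLamaQuantizer.Quantize("model-f16.gguf", "model-iq4_xs.gguf",
                                  LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS);
```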
@@ -5,11 +5,6 @@
     /// </summary>
     public enum LLamaLogLevel
     {
-        /// <summary>
-        /// Logs that are used for interactive investigation during development.
-        /// </summary>
-        Debug = 1,
-
         /// <summary>
         /// Logs that highlight when the current flow of execution is stopped due to a failure.
         /// </summary>
@@ -23,6 +18,11 @@
         /// <summary>
         /// Logs that track the general flow of the application.
         /// </summary>
-        Info = 4
+        Info = 4,
+
+        /// <summary>
+        /// Logs that are used for interactive investigation during development.
+        /// </summary>
+        Debug = 5,
     }
 }
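With `Debug` moved from 1 to 5 the numeric values now match llama.cpp's native levels, but they no longer line up with `Microsoft.Extensions.Logging.LogLevel`, so forwarding native logs needs an explicit translation. An illustrative sketch (the `Warning` member name is assumed from context, as it is not shown in this hunk):

```csharp
using LLama.Native;
using Microsoft.Extensions.Logging;

static LogLevel ToLogLevel(LLamaLogLevel level) => level switch
{
    LLamaLogLevel.Error   => LogLevel.Error,
    LLamaLogLevel.Warning => LogLevel.Warning,
    LLamaLogLevel.Info    => LogLevel.Information,
    LLamaLogLevel.Debug   => LogLevel.Debug, // now 5, after the renumbering
    _ => LogLevel.None,
};
```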
@@ -48,20 +48,21 @@ public unsafe struct LLamaModelMetadataOverride
 /// <summary>
 /// Specifies what type of value is being overridden by LLamaModelKvOverride
 /// </summary>
+/// <remarks>llama_model_kv_override_type</remarks>
 public enum LLamaModelKvOverrideType
 {
     /// <summary>
     /// Overriding an int value
     /// </summary>
-    LLAMA_KV_OVERRIDE_INT = 0,
+    Int = 0,

     /// <summary>
     /// Overriding a float value
     /// </summary>
-    LLAMA_KV_OVERRIDE_FLOAT = 1,
+    Float = 1,

     /// <summary>
     /// Overriding a bool value
     /// </summary>
-    LLAMA_KV_OVERRIDE_BOOL = 2,
+    Bool = 2,
 }
@@ -0,0 +1,12 @@
+namespace LLama.Native;
+
+/// <summary>
+/// How to pool (combine) embedding results by sequence id
+/// </summary>
+/// <remarks>llama_pooling_type</remarks>
+public enum LLamaPoolingType
+{
+    None = 0,
+    Mean = 1,
+    CLS = 2,
+}
@@ -0,0 +1,9 @@
+namespace LLama.Native;
+
+public enum LLamaRopeType
+{
+    None = -1,
+    Norm = 0,
+    NEOX = 2,
+    GLM = 4,
+}
@@ -0,0 +1,12 @@
+namespace LLama.Native;
+
+/// <summary>
+/// The tokenizer vocabulary type used by a model
+/// </summary>
+/// <remarks>llama_vocab_type</remarks>
+public enum LLamaVocabType
+{
+    SentencePiece = 0,
+    BytePairEncoding = 1,
+    WordPiece = 2,
+}
@@ -30,7 +30,7 @@ namespace LLama.Native
                     "4. Try to compile llama.cpp yourself to generate a libllama library, then use `LLama.Native.NativeLibraryConfig.WithLibrary` " +
                     "to specify it at the very beginning of your code. For more information about compilation, please refer to the LLamaSharp repo on GitHub.\n");
             }

-            llama_backend_init(false);
+            llama_backend_init();
         }

         private static void Log(string message, LogLevel level)
@@ -80,7 +80,16 @@ namespace LLama.Native
         /// Call once at the start of the program
         /// </summary>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        private static extern void llama_backend_init(bool numa);
+        private static extern void llama_backend_init();
+
+        // Note: llama_numa_init is not wrapped because we don't have a definition for `ggml_numa_strategy` in C#. That
+        // definition doesn't exist because it's not in llama.h; it's in ggml.h, which we don't currently build a wrapper
+        // for. If there's demand for better NUMA support, that will need adding.
+        ///// <summary>
+        ///// Optional, enable NUMA optimisations
+        ///// </summary>
+        //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        //public static extern void llama_numa_init(ggml_numa_strategy numa);

         /// <summary>
         /// Sets the current rng seed.
@@ -187,6 +196,13 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);

+        /// <summary>
+        /// Get the embeddings for the ith sequence. Equivalent to: llama_get_embeddings(ctx) + i*n_embd
+        /// </summary>
+        /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
+
         /// <summary>
         /// Get the embeddings for the input
         /// </summary>
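A small unsafe sketch of consuming the new per-sequence accessor. The binding returns a raw pointer with no length, so the caller must supply `nEmbd` (the model's embedding size):

```csharp
using System;
using LLama.Native;

static unsafe float[] GetSequenceEmbedding(SafeLLamaContextHandle ctx, int i, int nEmbd)
{
    // Equivalent to llama_get_embeddings(ctx) + i * n_embd on the native side
    float* source = NativeApi.llama_get_embeddings_ith(ctx, i);

    // Copy out of native memory before the context mutates it
    var result = new float[nEmbd];
    new Span<float>(source, nEmbd).CopyTo(result);
    return result;
}
```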
@@ -204,6 +220,22 @@ namespace LLama.Native
             static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
         }

+        /// <summary>
+        /// Apply chat template. Inspired by hf apply_chat_template() on python.
+        /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model".
+        /// NOTE: This function does not use a jinja parser. It only supports a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
+        /// </summary>
+        /// <param name="model"></param>
+        /// <param name="tmpl">A Jinja template to use for this chat. If this is nullptr, the model's default chat template will be used instead.</param>
+        /// <param name="chat">Pointer to a list of multiple llama_chat_message</param>
+        /// <param name="n_msg">Number of llama_chat_message in this chat</param>
+        /// <param name="add_ass">Whether to end the prompt with the token(s) that indicate the start of an assistant message.</param>
+        /// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
+        /// <param name="length">The size of the allocated buffer</param>
+        /// <returns>The total number of bytes of the formatted prompt. If it is larger than the size of buffer, you may need to re-alloc it and then re-apply the template.</returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nint n_msg, bool add_ass, char* buf, int length);

         /// <summary>
         /// Get the "Beginning of sentence" token
         /// </summary>
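A heavily hedged sketch of driving this binding together with the new `LLamaChatMessage` struct. The binding declares `char*` pointers while the native side expects raw UTF-8 bytes, hence the casts; real code should re-allocate and retry when the return value exceeds the buffer length rather than truncating:

```csharp
using System;
using System.Text;
using LLama.Native;

static unsafe string ApplyTemplate(SafeLlamaModelHandle model, string role, string content)
{
    // Null-terminated UTF-8 copies of the managed strings
    var roleBytes = Encoding.UTF8.GetBytes(role + "\0");
    var contentBytes = Encoding.UTF8.GetBytes(content + "\0");

    // Recommended size: 2x the total character count (plus some slack here)
    var buffer = new byte[2 * (role.Length + content.Length) + 256];

    fixed (byte* rolePtr = roleBytes)
    fixed (byte* contentPtr = contentBytes)
    fixed (byte* bufPtr = buffer)
    {
        var message = new LLamaChatMessage { role = rolePtr, content = contentPtr };

        // null template selects the model's built-in chat template
        var written = NativeApi.llama_chat_apply_template(model, null, &message, 1, true, (char*)bufPtr, buffer.Length);

        // A larger return value means the buffer was too small; this sketch truncates
        return Encoding.UTF8.GetString(buffer, 0, Math.Min(written, buffer.Length));
    }
}
```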
@@ -371,7 +403,9 @@ namespace LLama.Native
         /// <summary>
         /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
-        /// If the KV cache is RoPEd, the KV data is updated accordingly
+        /// If the KV cache is RoPEd, the KV data is updated accordingly:
+        ///  - lazily on next llama_decode()
+        ///  - explicitly with llama_kv_cache_update()
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="seq"></param>
@@ -379,12 +413,16 @@ namespace LLama.Native
         /// <param name="p1"></param>
         /// <param name="delta"></param>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);
+        public static extern void llama_kv_cache_seq_add(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);

         /// <summary>
         /// Integer division of the positions by factor of `d > 1`
-        /// If the KV cache is RoPEd, the KV data is updated accordingly
+        /// If the KV cache is RoPEd, the KV data is updated accordingly:
+        ///  - lazily on next llama_decode()
+        ///  - explicitly with llama_kv_cache_update()
+        /// <br />
         /// p0 < 0 : [0, p1]
         /// <br />
         /// p1 < 0 : [p0, inf)
         /// </summary>
         /// <param name="ctx"></param>
@@ -395,6 +433,32 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_kv_cache_seq_div(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int d);

+        /// <summary>
+        /// Returns the largest position present in the KV cache for the specified sequence
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="seq"></param>
+        /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern LLamaPos llama_kv_cache_seq_pos_max(SafeLLamaContextHandle ctx, LLamaSeqId seq);
+
+        /// <summary>
+        /// Defragment the KV cache. This will be applied:
+        ///  - lazily on next llama_decode()
+        ///  - explicitly with llama_kv_cache_update()
+        /// </summary>
+        /// <param name="ctx"></param>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern void llama_kv_cache_defrag(SafeLLamaContextHandle ctx);
+
+        /// <summary>
+        /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+        /// </summary>
+        /// <param name="ctx"></param>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx);
+
         /// <summary>
         /// Allocates a batch of tokens on the heap
         /// Each token can be assigned up to n_seq_max sequence ids
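Together these three imports make cache maintenance explicit rather than something that only happens inside `llama_decode()`. A sketch, assuming `ctx` is an initialised context handle:

```csharp
using LLama.Native;

// Compact the cache and report how far sequence 0 currently extends
static LLamaPos CompactCache(SafeLLamaContextHandle ctx)
{
    // Queue a defragmentation pass, then force it (and any pending
    // K-shifts) to run immediately instead of waiting for the next decode
    NativeApi.llama_kv_cache_defrag(ctx);
    NativeApi.llama_kv_cache_update(ctx);

    return NativeApi.llama_kv_cache_seq_pos_max(ctx, (LLamaSeqId)0);
}
```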
@@ -438,5 +502,11 @@ namespace LLama.Native
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern LLamaVocabType llama_vocab_type(SafeLlamaModelHandle model);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern LLamaRopeType llama_rope_type(SafeLlamaModelHandle model);
     }
 }
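A one-liner sketch of the new model queries, assuming `weights` is a loaded `LLamaWeights` whose `NativeHandle` is a `SafeLlamaModelHandle`:

```csharp
using System;
using LLama;
using LLama.Native;

static void PrintModelInfo(LLamaWeights weights)
{
    LLamaVocabType vocab = NativeApi.llama_vocab_type(weights.NativeHandle);
    LLamaRopeType rope = NativeApi.llama_rope_type(weights.NativeHandle);
    Console.WriteLine($"vocab={vocab} rope={rope}");
}
```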
@@ -5,26 +5,26 @@
     /// </summary>
     /// <remarks>C# equivalent of llama_rope_scaling_type</remarks>
     public enum RopeScalingType
-        : sbyte
+        : int
     {
         /// <summary>
         /// No particular scaling type has been specified
         /// </summary>
-        LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
+        Unspecified = -1,

         /// <summary>
         /// Do not apply any RoPE scaling
         /// </summary>
-        LLAMA_ROPE_SCALING_NONE = 0,
+        None = 0,

         /// <summary>
         /// Positional linear interpolation, as described by kaiokendev: https://kaiokendev.github.io/til#extending-context-to-8k
         /// </summary>
-        LLAMA_ROPE_SCALING_LINEAR = 1,
+        Linear = 1,

         /// <summary>
         /// YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf
         /// </summary>
-        LLAMA_ROPE_SCALING_YARN = 2,
+        Yarn = 2,
     }
 }
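The shortened names read more naturally at configuration time. A sketch through the existing YaRN options on `IContextParams`:

```csharp
using LLama.Common;
using LLama.Native;

var parameters = new ModelParams("model.gguf")
{
    // Enable YaRN scaling to run beyond the model's trained context length
    YarnScalingType = RopeScalingType.Yarn,
    YarnOriginalContext = 4096, // the context size the model was trained with
    ContextSize = 16384,
};
```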
@@ -2,7 +2,6 @@
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using System.Text;
-using System.Threading;
 using LLama.Exceptions;

 namespace LLama.Native
@@ -112,8 +111,24 @@ namespace LLama.Native
         /// <param name="ctx"></param>
         [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
         private static extern void llama_free(IntPtr ctx);
-        #endregion

+        /// <summary>
+        /// Set a callback which can abort computation
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="abort_callback"></param>
+        /// <param name="abort_callback_data"></param>
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern unsafe void llama_set_abort_callback(SafeLLamaContextHandle ctx, GgmlAbortCallback abort_callback, void* abort_callback_data);
+
+        /// <summary>
+        /// If this returns true computation is cancelled
+        /// </summary>
+        /// <param name="data"></param>
+        /// <returns></returns>
+        private unsafe delegate bool GgmlAbortCallback(void* data);
+        #endregion

         /// <summary>
         /// Token logits obtained from the last call to llama_decode
         /// The logits for the last token are stored in the last row
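The import is private, so nothing public exposes it yet. This illustrative sketch re-declares the binding to show the callback contract; the library name and wrapper shape are assumptions, not LLamaSharp API:

```csharp
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Threading;
using LLama.Native;

internal static unsafe class AbortCallbackExample
{
    private unsafe delegate bool GgmlAbortCallback(void* data);

    [DllImport("llama", CallingConvention = CallingConvention.Cdecl)] // assumed library name
    private static extern void llama_set_abort_callback(SafeLLamaContextHandle ctx, GgmlAbortCallback callback, void* userData);

    // Keep delegates alive: native code holds raw function pointers, so the
    // GC must not collect them while the context is in use
    private static readonly List<GgmlAbortCallback> _alive = new();

    public static void HookCancellation(SafeLLamaContextHandle ctx, CancellationToken token)
    {
        GgmlAbortCallback callback = _ => token.IsCancellationRequested;
        _alive.Add(callback);

        // Returning true from the callback cancels the in-flight computation
        llama_set_abort_callback(ctx, callback, null);
    }
}
```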
@@ -390,9 +405,9 @@ namespace LLama.Native
         /// <param name="p0"></param>
         /// <param name="p1"></param>
         /// <param name="delta"></param>
-        public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
+        public void KvCacheSequenceAdd(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
         {
-            NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
+            NativeApi.llama_kv_cache_seq_add(this, seq, p0, p1, delta);
         }

         /// <summary>