- Integrated grammar into sampling
- Added a test for the grammar sampling (tag: v0.5.1)
| @@ -0,0 +1,62 @@ | |||
using LLama.Common;
using LLama.Native;

namespace LLama.Unittest
{
    /// <summary>
    /// Tests covering grammar handle creation and grammar-constrained sampling.
    /// </summary>
    public sealed class GrammarTest
        : IDisposable
    {
        // Shared model instance for all tests in this class.
        // NOTE(review): assumes the model file exists relative to the test working directory — confirm CI copies it.
        private readonly LLamaModel _model = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048));

        public void Dispose()
        {
            _model.Dispose();
        }

        [Fact]
        public void CreateBasicGrammar()
        {
            // One rule matching a single character in the inclusive range [a-z]
            var grammarRules = new List<List<LLamaGrammarElement>>
            {
                new()
                {
                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
                    new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
                },
            };

            // Creating (and disposing) the handle must not throw
            using var handle = SafeLLamaGrammarHandle.Create(grammarRules, 0);
        }

        [Fact]
        public void SampleWithTrivialGrammar()
        {
            // Create a grammar that constrains the output to be "one" and nothing else
            var grammarRules = new List<List<LLamaGrammarElement>>
            {
                new()
                {
                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'o'),
                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'n'),
                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'e'),
                    new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
                },
            };
            using var grammar = SafeLLamaGrammarHandle.Create(grammarRules, 0);

            var executor = new StatelessExecutor(_model);
            var inferenceParams = new InferenceParams
            {
                MaxTokens = 3,
                AntiPrompts = new[] { ".", "Input:", "\n" },
                Grammar = grammar,
            };

            var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();

            // The grammar admits only the string "one"
            Assert.Equal("one", result[0]);
        }
    }
}
| @@ -1,5 +1,6 @@ | |||
| using LLama.Common; | |||
| using LLama.Abstractions; | |||
| using LLama.Native; | |||
| namespace LLama.Web.Common | |||
| { | |||
| @@ -95,5 +96,10 @@ namespace LLama.Web.Common | |||
| /// consider newlines as a repeatable token (penalize_nl) | |||
| /// </summary> | |||
| public bool PenalizeNL { get; set; } = true; | |||
| } | |||
| /// <summary> | |||
| /// A grammar to constrain possible tokens | |||
| /// </summary> | |||
| public SafeLLamaGrammarHandle Grammar { get; set; } = null; | |||
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System.Collections.Generic; | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| namespace LLama.Abstractions | |||
| { | |||
| @@ -113,5 +114,10 @@ namespace LLama.Abstractions | |||
| /// consider newlines as a repeatable token (penalize_nl) | |||
| /// </summary> | |||
| public bool PenalizeNL { get; set; } | |||
| /// <summary> | |||
| /// Grammar to constrain possible tokens | |||
| /// </summary> | |||
| SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using LLama.Abstractions; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using LLama.Native; | |||
| namespace LLama.Common | |||
| { | |||
| @@ -96,6 +97,11 @@ namespace LLama.Common | |||
| /// consider newlines as a repeatable token (penalize_nl) | |||
| /// </summary> | |||
| public bool PenalizeNL { get; set; } = true; | |||
| /// <summary> | |||
| /// A grammar to constrain the possible tokens | |||
| /// </summary> | |||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| } | |||
| /// <summary> | |||
| @@ -293,11 +293,19 @@ namespace LLama | |||
| /// <param name="topP"></param> | |||
| /// <param name="tfsZ"></param> | |||
| /// <param name="typicalP"></param> | |||
| /// <param name="grammar"></param> | |||
| /// <returns></returns> | |||
| public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, | |||
| float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) | |||
| float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f, | |||
| SafeLLamaGrammarHandle? grammar = null) | |||
| { | |||
| llama_token id; | |||
| if (grammar != null) | |||
| { | |||
| SamplingApi.llama_sample_grammar(_ctx, candidates, grammar); | |||
| } | |||
| if (temperature <= 0) | |||
| { | |||
| // Greedy sampling | |||
| @@ -331,6 +339,12 @@ namespace LLama | |||
| } | |||
| mirostat_mu = mu; | |||
| } | |||
| if (grammar != null) | |||
| { | |||
| NativeApi.llama_grammar_accept_token(_ctx, grammar, id); | |||
| } | |||
| return id; | |||
| } | |||
| @@ -217,7 +217,8 @@ namespace LLama | |||
| var mu = MirostatMu; | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, | |||
| inferenceParams.Grammar | |||
| ); | |||
| MirostatMu = mu; | |||
| @@ -206,7 +206,8 @@ namespace LLama | |||
| var mu = MirostatMu; | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, | |||
| inferenceParams.Grammar | |||
| ); | |||
| MirostatMu = mu; | |||
| @@ -72,7 +72,7 @@ namespace LLama | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP); | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); | |||
| lastTokens.Add(id); | |||
| @@ -33,14 +33,14 @@ namespace LLama.Native | |||
| CHAR_NOT = 4, | |||
| /// <summary> | |||
| /// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to | |||
| /// modifies a preceding CHAR or CHAR_ALT to | |||
| /// be an inclusive range ([a-z]) | |||
| /// </summary> | |||
| CHAR_RNG_UPPER = 5, | |||
| /// <summary> | |||
| /// modifies a preceding LLAMA_GRETYPE_CHAR or | |||
| /// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) | |||
| /// modifies a preceding CHAR or | |||
| /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) | |||
| /// </summary> | |||
| CHAR_ALT = 6, | |||
| }; | |||
| @@ -60,5 +60,16 @@ namespace LLama.Native | |||
| /// Unicode code point or rule ID | |||
| /// </summary> | |||
| public uint Value; | |||
/// <summary>
/// Construct a new LLamaGrammarElement of the given type and value
/// </summary>
/// <param name="type">Element type (e.g. CHAR, END)</param>
/// <param name="value">Unicode code point or rule ID, interpreted according to <paramref name="type"/></param>
public LLamaGrammarElement(LLamaGrammarElementType type, uint value)
{
    // Field assignments are independent; order is irrelevant
    Value = value;
    Type = type;
}
| } | |||
| } | |||
| @@ -7,13 +7,21 @@ namespace LLama.Native | |||
| public unsafe partial class NativeApi | |||
| { | |||
| //todo: LLAMA_API struct llama_grammar * llama_grammar_init(const llama_grammar_element** rules, size_t n_rules,size_t start_rule_index); | |||
| /// <summary> | |||
| /// Free all memory from the given SafeLLamaGrammarHandle | |||
| /// Create a new grammar from the given set of grammar rules | |||
| /// </summary> | |||
| /// <param name="grammar"></param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| /// <param name="rules"></param> | |||
| /// <param name="n_rules"></param> | |||
| /// <param name="start_rule_index"></param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index); | |||
| /// <summary> | |||
| /// Free all memory from the given SafeLLamaGrammarHandle | |||
| /// </summary> | |||
| /// <param name="grammar"></param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern void llama_grammar_free(IntPtr grammar); | |||
| /// <summary> | |||
| @@ -1,4 +1,9 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using LLama.Exceptions; | |||
| namespace LLama.Native | |||
| { | |||
| @@ -8,6 +13,11 @@ namespace LLama.Native | |||
| public class SafeLLamaGrammarHandle | |||
| : SafeLLamaHandleBase | |||
| { | |||
| #region construction/destruction | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="handle"></param> | |||
| internal SafeLLamaGrammarHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| @@ -20,5 +30,76 @@ namespace LLama.Native | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
/// <summary>
/// Create a new llama_grammar from managed rule lists
/// </summary>
/// <param name="rules">A list of lists of elements; each inner list makes up one grammar rule</param>
/// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
/// <returns>A safe handle owning the native grammar</returns>
/// <exception cref="RuntimeError">Thrown when the native grammar cannot be created</exception>
public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
{
    unsafe
    {
        var ruleCount = (ulong)rules.Count;
        var elementCount = rules.Sum(r => r.Count);

        // Rent one flat buffer large enough for every element of every rule,
        // plus a buffer of pointers (one per rule) into that flat buffer.
        var elementBuffer = ArrayPool<LLamaGrammarElement>.Shared.Rent(elementCount);
        var rulePointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
        try
        {
            // Pin the element buffer so the recorded pointers stay valid
            // for the duration of the native call.
            fixed (LLamaGrammarElement* elementsPtr = elementBuffer)
            {
                var nextElement = 0;
                var nextPointer = 0;
                foreach (var rule in rules)
                {
                    // Record where this rule begins in the flat buffer
                    rulePointers[nextPointer++] = (IntPtr)(elementsPtr + nextElement);

                    // Append this rule's elements to the flat buffer
                    foreach (var element in rule)
                        elementBuffer[nextElement++] = element;
                }

                // If the copy worked as planned we consumed exactly the pre-counted amounts
                Debug.Assert((ulong)nextPointer == ruleCount);
                Debug.Assert(nextElement == elementCount);

                // Hand the pointer table through to llama.cpp
                fixed (void* pointersPtr = rulePointers)
                {
                    return Create((LLamaGrammarElement**)pointersPtr, ruleCount, start_rule_index);
                }
            }
        }
        finally
        {
            ArrayPool<LLamaGrammarElement>.Shared.Return(elementBuffer);
            ArrayPool<IntPtr>.Shared.Return(rulePointers);
        }
    }
}
/// <summary>
/// Create a new llama_grammar directly from native rule pointers
/// </summary>
/// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
/// <param name="nrules">total number of rules</param>
/// <param name="start_rule_index">index of the start rule of the grammar</param>
/// <returns>A safe handle owning the native grammar</returns>
/// <exception cref="RuntimeError">Thrown when llama_grammar_init returns a null pointer</exception>
public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
{
    var ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);

    // A null pointer signals that the native side rejected the rules
    if (ptr == IntPtr.Zero)
        throw new RuntimeError("Failed to create grammar from rules");

    return new SafeLLamaGrammarHandle(ptr);
}
| #endregion | |||
| } | |||
| } | |||
| @@ -5,6 +5,18 @@ namespace LLama.Native | |||
| using llama_token = Int32; | |||
| public unsafe class SamplingApi | |||
| { | |||
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx">Context the candidates were produced from</param>
/// <param name="candidates">Candidate token set to filter in place</param>
/// <param name="grammar">Grammar constraining which tokens remain viable</param>
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
    // Pin the managed candidate array into a native struct for the call.
    // NOTE(review): the native call edits the pinned token data directly, but any
    // changes it makes to the struct's size/sorted fields are not copied back to
    // `candidates` here — confirm callers do not rely on those fields afterwards.
    using var pin = LLamaTokenDataArrayNative.Create(candidates, out var native);
    NativeApi.llama_sample_grammar(ctx, ref native, grammar);
}
| /// <summary> | |||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | |||
| /// </summary> | |||