diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs
new file mode 100644
index 00000000..c813a573
--- /dev/null
+++ b/LLama.Unittest/GrammarTest.cs
@@ -0,0 +1,62 @@
+using LLama.Common;
+using LLama.Native;
+
+namespace LLama.Unittest
+{
+    public sealed class GrammarTest
+        : IDisposable
+    {
+        private readonly LLamaModel _model = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048));
+
+        public void Dispose()
+        {
+            _model.Dispose();
+        }
+
+        [Fact]
+        public void CreateBasicGrammar()
+        {
+            var rules = new List<List<LLamaGrammarElement>>
+            {
+                new()
+                {
+                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
+                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
+                    new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
+                },
+            };
+
+            using var handle = SafeLLamaGrammarHandle.Create(rules, 0);
+        }
+
+        [Fact]
+        public void SampleWithTrivialGrammar()
+        {
+            // Create a grammar that constrains the output to be "one" and nothing else
+            var rules = new List<List<LLamaGrammarElement>>
+            {
+                new()
+                {
+                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'o'),
+                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'n'),
+                    new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'e'),
+                    new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
+                },
+            };
+
+            using var grammar = SafeLLamaGrammarHandle.Create(rules, 0);
+
+            var executor = new StatelessExecutor(_model);
+            var inferenceParams = new InferenceParams
+            {
+                MaxTokens = 3,
+                AntiPrompts = new [] { ".", "Input:", "\n" },
+                Grammar = grammar,
+            };
+
+            var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();
+
+            Assert.Equal("one", result[0]);
+        }
+    }
+}
diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs
index 7677f04a..f78aa861 100644
--- a/LLama.Web/Common/ParameterOptions.cs
+++ b/LLama.Web/Common/ParameterOptions.cs
@@ -1,5 +1,6 @@
 using LLama.Common;
 using LLama.Abstractions;
+using LLama.Native;
 
 namespace LLama.Web.Common
 {
@@ -95,5 +96,10 @@ namespace LLama.Web.Common
         /// consider newlines as a repeatable token (penalize_nl)
         /// </summary>
         public bool PenalizeNL { get; set; } = true;
-    }
+
+        /// <summary>
+        /// A grammar to constrain possible tokens
+        /// </summary>
+        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+    }
 }
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index 73cbbfd2..e576366f 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -1,5 +1,6 @@
 using System.Collections.Generic;
 using LLama.Common;
+using LLama.Native;
 
 namespace LLama.Abstractions
 {
@@ -113,5 +114,10 @@ namespace LLama.Abstractions
         /// consider newlines as a repeatable token (penalize_nl)
         /// </summary>
         public bool PenalizeNL { get; set; }
+
+        /// <summary>
+        /// Grammar to constrain possible tokens
+        /// </summary>
+        SafeLLamaGrammarHandle? Grammar { get; set; }
     }
 }
\ No newline at end of file
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index 001a8f8e..64d2652b 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -1,6 +1,7 @@
 using LLama.Abstractions;
 using System;
 using System.Collections.Generic;
+using LLama.Native;
 
 namespace LLama.Common
 {
@@ -96,6 +97,11 @@ namespace LLama.Common
         /// consider newlines as a repeatable token (penalize_nl)
         /// </summary>
         public bool PenalizeNL { get; set; } = true;
+
+        /// <summary>
+        /// A grammar to constrain the possible tokens
+        /// </summary>
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
     }
 
     /// <summary>
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 1ef2a8db..9c053d37 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -293,11 +293,19 @@ namespace LLama
         /// <param name="topP"></param>
         /// <param name="tfsZ"></param>
         /// <param name="typicalP"></param>
+        /// <param name="grammar"></param>
         /// <returns></returns>
         public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
-            float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
+            float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
+            SafeLLamaGrammarHandle? grammar = null)
         {
             llama_token id;
+
+            if (grammar != null)
+            {
+                SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
+            }
+
             if (temperature <= 0)
             {
                 // Greedy sampling
@@ -331,6 +339,12 @@ namespace LLama
                 }
                 mirostat_mu = mu;
             }
+
+            if (grammar != null)
+            {
+                NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
+            }
+
             return id;
         }
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 6773cdde..de708785 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -217,7 +217,8 @@ namespace LLama
                 var mu = MirostatMu;
                 var id = Context.Sample(
                     tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
+                    inferenceParams.Grammar
                 );
                 MirostatMu = mu;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 533a1863..e65c6f19 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -206,7 +206,8 @@ namespace LLama
                 var mu = MirostatMu;
                 var id = Context.Sample(
                     tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
+                    inferenceParams.Grammar
                 );
                 MirostatMu = mu;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index f09ff7dd..c2fe4985 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -72,7 +72,7 @@ namespace LLama
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                 var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
 
                 lastTokens.Add(id);
diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs
index 6764b285..d097628f 100644
--- a/LLama/Native/LLamaGrammarElement.cs
+++ b/LLama/Native/LLamaGrammarElement.cs
@@ -33,14 +33,14 @@ namespace LLama.Native
         CHAR_NOT = 4,
 
         /// <summary>
-        /// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+        /// modifies a preceding CHAR or CHAR_ALT to
         /// be an inclusive range ([a-z])
         /// </summary>
         CHAR_RNG_UPPER = 5,
 
         /// <summary>
-        /// modifies a preceding LLAMA_GRETYPE_CHAR or
-        /// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+        /// modifies a preceding CHAR or
+        /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
         /// </summary>
         CHAR_ALT = 6,
     };
@@ -60,5 +60,16 @@ namespace LLama.Native
         /// Unicode code point or rule ID
         /// </summary>
         public uint Value;
+
+        /// <summary>
+        /// Construct a new LLamaGrammarElement
+        /// </summary>
+        /// <param name="type"></param>
+        /// <param name="value"></param>
+        public LLamaGrammarElement(LLamaGrammarElementType type, uint value)
+        {
+            Type = type;
+            Value = value;
+        }
     }
 }
diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs
index ef36756e..354ade3b 100644
--- a/LLama/Native/NativeApi.Grammar.cs
+++ b/LLama/Native/NativeApi.Grammar.cs
@@ -7,13 +7,21 @@ namespace LLama.Native
     public unsafe partial class NativeApi
     {
-        //todo: LLAMA_API struct llama_grammar * llama_grammar_init(const llama_grammar_element** rules, size_t n_rules,size_t start_rule_index);
-
         /// <summary>
-        /// Free all memory from the given SafeLLamaGrammarHandle
+        /// Create a new grammar from the given set of grammar rules
         /// </summary>
-        /// <param name="grammar"></param>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        /// <param name="rules"></param>
+        /// <param name="n_rules"></param>
+        /// <param name="start_rule_index"></param>
+        /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
+
+        /// <summary>
+        /// Free all memory from the given SafeLLamaGrammarHandle
+        /// </summary>
+        /// <param name="grammar"></param>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_grammar_free(IntPtr grammar);
 
         /// <summary>
diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs
index ca814c36..0b4eda9d 100644
--- a/LLama/Native/SafeLLamaGrammarHandle.cs
+++ b/LLama/Native/SafeLLamaGrammarHandle.cs
@@ -1,4 +1,9 @@
 using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using LLama.Exceptions;
 
 namespace LLama.Native
 {
@@ -8,6 +13,11 @@ namespace LLama.Native
     public class SafeLLamaGrammarHandle
         : SafeLLamaHandleBase
     {
+        #region construction/destruction
+        /// <summary>
+        ///
+        /// </summary>
+        /// <param name="handle"></param>
         internal SafeLLamaGrammarHandle(IntPtr handle)
             : base(handle)
         {
@@ -20,5 +30,76 @@ namespace LLama.Native
             SetHandle(IntPtr.Zero);
             return true;
         }
+
+        /// <summary>
+        /// Create a new llama_grammar
+        /// </summary>
+        /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
+        /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
+        {
+            unsafe
+            {
+                var totalElements = rules.Sum(a => a.Count);
+                var nrules = (ulong)rules.Count;
+
+                // Borrow an array large enough to hold every single element
+                // and another array large enough to hold a pointer to each rule
+                var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
+                var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
+                try
+                {
+                    fixed (LLamaGrammarElement* allElementsPtr = allElements)
+                    {
+                        var elementIndex = 0;
+                        var pointerIndex = 0;
+                        foreach (var rule in rules)
+                        {
+                            // Save a pointer to the start of this rule
+                            pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);
+
+                            // Copy all of the rule elements into the flat array
+                            foreach (var element in rule)
+                                allElementsPtr[elementIndex++] = element;
+                        }
+
+                        // Sanity check some things that should be true if the copy worked as planned
+                        Debug.Assert((ulong)pointerIndex == nrules);
+                        Debug.Assert(elementIndex == totalElements);
+
+                        // Make the actual call through to llama.cpp
+                        fixed (void* ptr = pointers)
+                        {
+                            return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
+                        }
+                    }
+                }
+                finally
+                {
+                    ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
+                    ArrayPool<IntPtr>.Shared.Return(pointers);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Create a new llama_grammar
+        /// </summary>
+        /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
+        /// <param name="nrules">total number of rules</param>
+        /// <param name="start_rule_index">index of the start rule of the grammar</param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
+        {
+            var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
+            if (grammar_ptr == IntPtr.Zero)
+                throw new RuntimeError("Failed to create grammar from rules");
+
+            return new(grammar_ptr);
+        }
+        #endregion
     }
 }
diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs
index abe01b15..fde2311b 100644
--- a/LLama/Native/SamplingApi.cs
+++ b/LLama/Native/SamplingApi.cs
@@ -5,6 +5,18 @@ namespace LLama.Native
     using llama_token = Int32;
 
     public unsafe class SamplingApi
     {
+        /// <summary>
+        /// Apply grammar rules to candidate tokens
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="candidates"></param>
+        /// <param name="grammar"></param>
+        public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
+        {
+            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
+            NativeApi.llama_sample_grammar(ctx, ref st, grammar);
+        }
+
         /// <summary>
         /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
         /// </summary>