- Integrated grammar into sampling - Added a test for the grammar sampling (tags/v0.5.1)
| @@ -0,0 +1,62 @@ | |||||
| using LLama.Common; | |||||
| using LLama.Native; | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| public sealed class GrammarTest | |||||
| : IDisposable | |||||
| { | |||||
| private readonly LLamaModel _model = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048)); | |||||
| public void Dispose() | |||||
| { | |||||
| _model.Dispose(); | |||||
| } | |||||
| [Fact] | |||||
| public void CreateBasicGrammar() | |||||
| { | |||||
| var rules = new List<List<LLamaGrammarElement>> | |||||
| { | |||||
| new() | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| }, | |||||
| }; | |||||
| using var handle = SafeLLamaGrammarHandle.Create(rules, 0); | |||||
| } | |||||
| [Fact] | |||||
| public void SampleWithTrivialGrammar() | |||||
| { | |||||
| // Create a grammar that constrains the output to be "one" and nothing else | |||||
| var rules = new List<List<LLamaGrammarElement>> | |||||
| { | |||||
| new() | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'o'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'n'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'e'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| }, | |||||
| }; | |||||
| using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); | |||||
| var executor = new StatelessExecutor(_model); | |||||
| var inferenceParams = new InferenceParams | |||||
| { | |||||
| MaxTokens = 3, | |||||
| AntiPrompts = new [] { ".", "Input:", "\n" }, | |||||
| Grammar = grammar, | |||||
| }; | |||||
| var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); | |||||
| Assert.Equal("one", result[0]); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,5 +1,6 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using LLama.Native; | |||||
| namespace LLama.Web.Common | namespace LLama.Web.Common | ||||
| { | { | ||||
| @@ -95,5 +96,10 @@ namespace LLama.Web.Common | |||||
| /// consider newlines as a repeatable token (penalize_nl) | /// consider newlines as a repeatable token (penalize_nl) | ||||
| /// </summary> | /// </summary> | ||||
| public bool PenalizeNL { get; set; } = true; | public bool PenalizeNL { get; set; } = true; | ||||
| } | |||||
| /// <summary> | |||||
| /// A grammar to constrain possible tokens | |||||
| /// </summary> | |||||
| public SafeLLamaGrammarHandle Grammar { get; set; } = null; | |||||
| } | |||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.Native; | |||||
| namespace LLama.Abstractions | namespace LLama.Abstractions | ||||
| { | { | ||||
| @@ -113,5 +114,10 @@ namespace LLama.Abstractions | |||||
| /// consider newlines as a repeatable token (penalize_nl) | /// consider newlines as a repeatable token (penalize_nl) | ||||
| /// </summary> | /// </summary> | ||||
| public bool PenalizeNL { get; set; } | public bool PenalizeNL { get; set; } | ||||
| /// <summary> | |||||
| /// Grammar to constrain possible tokens | |||||
| /// </summary> | |||||
| SafeLLamaGrammarHandle? Grammar { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using LLama.Native; | |||||
| namespace LLama.Common | namespace LLama.Common | ||||
| { | { | ||||
| @@ -96,6 +97,11 @@ namespace LLama.Common | |||||
| /// consider newlines as a repeatable token (penalize_nl) | /// consider newlines as a repeatable token (penalize_nl) | ||||
| /// </summary> | /// </summary> | ||||
| public bool PenalizeNL { get; set; } = true; | public bool PenalizeNL { get; set; } = true; | ||||
| /// <summary> | |||||
| /// A grammar to constrain the possible tokens | |||||
| /// </summary> | |||||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -293,11 +293,19 @@ namespace LLama | |||||
| /// <param name="topP"></param> | /// <param name="topP"></param> | ||||
| /// <param name="tfsZ"></param> | /// <param name="tfsZ"></param> | ||||
| /// <param name="typicalP"></param> | /// <param name="typicalP"></param> | ||||
| /// <param name="grammar"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, | public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, | ||||
| float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) | |||||
| float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f, | |||||
| SafeLLamaGrammarHandle? grammar = null) | |||||
| { | { | ||||
| llama_token id; | llama_token id; | ||||
| if (grammar != null) | |||||
| { | |||||
| SamplingApi.llama_sample_grammar(_ctx, candidates, grammar); | |||||
| } | |||||
| if (temperature <= 0) | if (temperature <= 0) | ||||
| { | { | ||||
| // Greedy sampling | // Greedy sampling | ||||
| @@ -331,6 +339,12 @@ namespace LLama | |||||
| } | } | ||||
| mirostat_mu = mu; | mirostat_mu = mu; | ||||
| } | } | ||||
| if (grammar != null) | |||||
| { | |||||
| NativeApi.llama_grammar_accept_token(_ctx, grammar, id); | |||||
| } | |||||
| return id; | return id; | ||||
| } | } | ||||
| @@ -217,7 +217,8 @@ namespace LLama | |||||
| var mu = MirostatMu; | var mu = MirostatMu; | ||||
| var id = Context.Sample( | var id = Context.Sample( | ||||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | ||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, | |||||
| inferenceParams.Grammar | |||||
| ); | ); | ||||
| MirostatMu = mu; | MirostatMu = mu; | ||||
| @@ -206,7 +206,8 @@ namespace LLama | |||||
| var mu = MirostatMu; | var mu = MirostatMu; | ||||
| var id = Context.Sample( | var id = Context.Sample( | ||||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | ||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, | |||||
| inferenceParams.Grammar | |||||
| ); | ); | ||||
| MirostatMu = mu; | MirostatMu = mu; | ||||
| @@ -72,7 +72,7 @@ namespace LLama | |||||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | ||||
| var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | ||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP); | |||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); | |||||
| lastTokens.Add(id); | lastTokens.Add(id); | ||||
| @@ -33,14 +33,14 @@ namespace LLama.Native | |||||
| CHAR_NOT = 4, | CHAR_NOT = 4, | ||||
| /// <summary> | /// <summary> | ||||
| /// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to | |||||
| /// modifies a preceding CHAR or CHAR_ALT to | |||||
| /// be an inclusive range ([a-z]) | /// be an inclusive range ([a-z]) | ||||
| /// </summary> | /// </summary> | ||||
| CHAR_RNG_UPPER = 5, | CHAR_RNG_UPPER = 5, | ||||
| /// <summary> | /// <summary> | ||||
| /// modifies a preceding LLAMA_GRETYPE_CHAR or | |||||
| /// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) | |||||
| /// modifies a preceding CHAR or | |||||
| /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) | |||||
| /// </summary> | /// </summary> | ||||
| CHAR_ALT = 6, | CHAR_ALT = 6, | ||||
| }; | }; | ||||
| @@ -60,5 +60,16 @@ namespace LLama.Native | |||||
| /// Unicode code point or rule ID | /// Unicode code point or rule ID | ||||
| /// </summary> | /// </summary> | ||||
| public uint Value; | public uint Value; | ||||
        /// <summary>
        /// Construct a new LLamaGrammarElement
        /// </summary>
        /// <param name="type">The kind of grammar element this is (e.g. CHAR, RULE_REF, END)</param>
        /// <param name="value">Unicode code point or rule ID, depending on <paramref name="type"/></param>
        public LLamaGrammarElement(LLamaGrammarElementType type, uint value)
        {
            Type = type;
            Value = value;
        }
| } | } | ||||
| } | } | ||||
| @@ -7,13 +7,21 @@ namespace LLama.Native | |||||
| public unsafe partial class NativeApi | public unsafe partial class NativeApi | ||||
| { | { | ||||
| //todo: LLAMA_API struct llama_grammar * llama_grammar_init(const llama_grammar_element** rules, size_t n_rules,size_t start_rule_index); | |||||
| /// <summary> | /// <summary> | ||||
| /// Free all memory from the given SafeLLamaGrammarHandle | |||||
| /// Create a new grammar from the given set of grammar rules | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="grammar"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| /// <param name="rules"></param> | |||||
| /// <param name="n_rules"></param> | |||||
| /// <param name="start_rule_index"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index); | |||||
| /// <summary> | |||||
| /// Free all memory from the given SafeLLamaGrammarHandle | |||||
| /// </summary> | |||||
| /// <param name="grammar"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_grammar_free(IntPtr grammar); | public static extern void llama_grammar_free(IntPtr grammar); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,4 +1,9 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using LLama.Exceptions; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| @@ -8,6 +13,11 @@ namespace LLama.Native | |||||
| public class SafeLLamaGrammarHandle | public class SafeLLamaGrammarHandle | ||||
| : SafeLLamaHandleBase | : SafeLLamaHandleBase | ||||
| { | { | ||||
| #region construction/destruction | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="handle"></param> | |||||
| internal SafeLLamaGrammarHandle(IntPtr handle) | internal SafeLLamaGrammarHandle(IntPtr handle) | ||||
| : base(handle) | : base(handle) | ||||
| { | { | ||||
| @@ -20,5 +30,76 @@ namespace LLama.Native | |||||
| SetHandle(IntPtr.Zero); | SetHandle(IntPtr.Zero); | ||||
| return true; | return true; | ||||
| } | } | ||||
        /// <summary>
        /// Create a new llama_grammar from managed rule lists, marshalling them into the
        /// flat pointer-per-rule layout the native API expects.
        /// </summary>
        /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
        /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
        /// <returns>A safe handle owning the newly created native grammar</returns>
        /// <exception cref="RuntimeError">Thrown if the native grammar could not be created</exception>
        public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
        {
            unsafe
            {
                var totalElements = rules.Sum(a => a.Count);
                var nrules = (ulong)rules.Count;

                // Borrow an array large enough to hold every single element
                // and another array large enough to hold a pointer to each rule.
                // Note: rented arrays may be longer than requested; only the leading
                // totalElements / rules.Count entries are ever written or read.
                var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
                var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
                try
                {
                    // Pin the element array so the per-rule pointers taken below
                    // remain valid for the duration of the native call
                    fixed (LLamaGrammarElement* allElementsPtr = allElements)
                    {
                        var elementIndex = 0;
                        var pointerIndex = 0;
                        foreach (var rule in rules)
                        {
                            // Save a pointer to the start of this rule
                            pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);

                            // Copy all of the rule elements into the flat array
                            foreach (var element in rule)
                                allElementsPtr[elementIndex++] = element;
                        }

                        // Sanity check some things that should be true if the copy worked as planned
                        Debug.Assert((ulong)pointerIndex == nrules);
                        Debug.Assert(elementIndex == totalElements);

                        // Make the actual call through to llama.cpp (the pointer array must
                        // also be pinned while the native side reads it)
                        fixed (void* ptr = pointers)
                        {
                            return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                        }
                    }
                }
                finally
                {
                    // Return the borrowed buffers whether or not creation succeeded
                    ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
                    ArrayPool<IntPtr>.Shared.Return(pointers);
                }
            }
        }
| /// <summary> | |||||
| /// Create a new llama_grammar | |||||
| /// </summary> | |||||
| /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param> | |||||
| /// <param name="nrules">total number of rules</param> | |||||
| /// <param name="start_rule_index">index of the start rule of the grammar</param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index) | |||||
| { | |||||
| var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index); | |||||
| if (grammar_ptr == IntPtr.Zero) | |||||
| throw new RuntimeError("Failed to create grammar from rules"); | |||||
| return new(grammar_ptr); | |||||
| } | |||||
| #endregion | |||||
| } | } | ||||
| } | } | ||||
| @@ -5,6 +5,18 @@ namespace LLama.Native | |||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| public unsafe class SamplingApi | public unsafe class SamplingApi | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Apply grammar rules to candidate tokens | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates"></param> | |||||
| /// <param name="grammar"></param> | |||||
| public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar) | |||||
| { | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_grammar(ctx, ref st, grammar); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | ||||
| /// </summary> | /// </summary> | ||||