From db7ecf5a43a0472ed87c07e22e50a9a878d9df5c Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 15 Dec 2023 23:01:05 +0000 Subject: [PATCH] Added a method to create a clone of a grammar instance --- LLama.Unittest/GrammarTest.cs | 3 +- LLama/Native/NativeApi.Grammar.cs | 8 ++++++ LLama/Native/NativeApi.cs | 40 +++++++++++++++++++++++--- LLama/Native/SafeLLamaGrammarHandle.cs | 9 ++++++ 4 files changed, 55 insertions(+), 5 deletions(-) diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs index 3d7d1dad..870ed05a 100644 --- a/LLama.Unittest/GrammarTest.cs +++ b/LLama.Unittest/GrammarTest.cs @@ -74,13 +74,14 @@ namespace LLama.Unittest var grammar = new Grammar(rules, 0); using var grammarInstance = grammar.CreateInstance(); + using var grammarInstance2 = grammarInstance.Clone(); var executor = new StatelessExecutor(_model, _params); var inferenceParams = new InferenceParams { MaxTokens = 3, AntiPrompts = new [] { ".", "Input:", "\n" }, - Grammar = grammarInstance, + Grammar = grammarInstance2, }; var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync(); diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs index 354ade3b..84e298c7 100644 --- a/LLama/Native/NativeApi.Grammar.cs +++ b/LLama/Native/NativeApi.Grammar.cs @@ -24,6 +24,14 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_grammar_free(IntPtr grammar); + /// + /// Create a copy of an existing grammar instance + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr llama_grammar_copy(SafeLLamaGrammarHandle grammar); + /// /// Apply constraints from grammar /// diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index ca6027fc..24b9f571 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -254,6 +254,12 @@ namespace LLama.Native } } + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token); + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token); + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, llama_token token); @@ -330,6 +336,34 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_add_eos_token(SafeLlamaModelHandle model); + /// + /// codellama infill tokens, Beginning of infill prefix + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_token_prefix(SafeLlamaModelHandle model); + + /// + /// codellama infill tokens, Beginning of infill middle + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_token_middle(SafeLlamaModelHandle model); + + /// + /// codellama infill tokens, Beginning of infill suffix + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_token_suffix(SafeLlamaModelHandle model); + + /// + /// codellama infill tokens, End of infill middle + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_token_eot(SafeLlamaModelHandle model); + /// /// Print out timing information for this context /// @@ -485,13 +519,11 @@ namespace LLama.Native public static extern void llama_log_set(LLamaLogCallback logCallback); /// - /// Remove all tokens data of cells in [c0, c1) + /// Clear the KV cache /// /// - /// - /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_kv_cache_tokens_rm(SafeLLamaContextHandle ctx, int c0, int c1); + public static extern void llama_kv_cache_clear(SafeLLamaContextHandle ctx); /// /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs index ee27befd..49096d44 100644 --- a/LLama/Native/SafeLLamaGrammarHandle.cs +++ b/LLama/Native/SafeLLamaGrammarHandle.cs @@ -105,6 +105,15 @@ namespace LLama.Native } #endregion + /// + /// Create a copy of this grammar instance + /// + /// + public SafeLLamaGrammarHandle Clone() + { + return new SafeLLamaGrammarHandle(NativeApi.llama_grammar_copy(this)); + } + /// /// Accepts the sampled token into the grammar ///