Browse Source

Added a method to create a clone of a grammar instance

tags/0.9.1
Martin Evans 1 year ago
parent
commit
db7ecf5a43
4 changed files with 55 additions and 5 deletions
  1. +2
    -1
      LLama.Unittest/GrammarTest.cs
  2. +8
    -0
      LLama/Native/NativeApi.Grammar.cs
  3. +36
    -4
      LLama/Native/NativeApi.cs
  4. +9
    -0
      LLama/Native/SafeLLamaGrammarHandle.cs

+ 2
- 1
LLama.Unittest/GrammarTest.cs View File

@@ -74,13 +74,14 @@ namespace LLama.Unittest


var grammar = new Grammar(rules, 0); var grammar = new Grammar(rules, 0);
using var grammarInstance = grammar.CreateInstance(); using var grammarInstance = grammar.CreateInstance();
using var grammarInstance2 = grammarInstance.Clone();


var executor = new StatelessExecutor(_model, _params); var executor = new StatelessExecutor(_model, _params);
var inferenceParams = new InferenceParams var inferenceParams = new InferenceParams
{ {
MaxTokens = 3, MaxTokens = 3,
AntiPrompts = new [] { ".", "Input:", "\n" }, AntiPrompts = new [] { ".", "Input:", "\n" },
Grammar = grammarInstance,
Grammar = grammarInstance2,
}; };


var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync(); var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();


+ 8
- 0
LLama/Native/NativeApi.Grammar.cs View File

@@ -24,6 +24,14 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_grammar_free(IntPtr grammar); public static extern void llama_grammar_free(IntPtr grammar);


/// <summary>
/// Create a copy of an existing grammar instance
/// </summary>
/// <param name="grammar">Handle of the grammar instance to copy</param>
/// <returns>
/// Raw native pointer to the copied grammar. SafeLLamaGrammarHandle.Clone wraps this pointer in a new
/// SafeLLamaGrammarHandle. NOTE(review): presumably the copy is independent of the original and must be
/// released with llama_grammar_free - confirm against llama.h.
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_grammar_copy(SafeLLamaGrammarHandle grammar);

/// <summary> /// <summary>
/// Apply constraints from grammar /// Apply constraints from grammar
/// </summary> /// </summary>


+ 36
- 4
LLama/Native/NativeApi.cs View File

@@ -254,6 +254,12 @@ namespace LLama.Native
} }
} }


/// <summary>
/// Native binding for <c>llama_token_get_text</c>
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <param name="token">Token to look up</param>
/// <returns>
/// Raw byte pointer returned by the native library. NOTE(review): presumably the text of the token; its
/// encoding, termination and lifetime are not visible here - confirm against llama.h before dereferencing.
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);

/// <summary>
/// Native binding for <c>llama_token_get_score</c>
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <param name="token">Token to look up</param>
/// <returns>
/// Floating point value returned by the native library. NOTE(review): presumably the vocabulary score of
/// the token - confirm against llama.h.
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, llama_token token); public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, llama_token token);


@@ -330,6 +336,34 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_add_eos_token(SafeLlamaModelHandle model); public static extern int llama_add_eos_token(SafeLlamaModelHandle model);


/// <summary>
/// codellama infill tokens, Beginning of infill prefix
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>Integer reported by the native library (presumably the token id - confirm against llama.h)</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_prefix(SafeLlamaModelHandle model);

/// <summary>
/// codellama infill tokens, Beginning of infill middle
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>Integer reported by the native library (presumably the token id - confirm against llama.h)</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_middle(SafeLlamaModelHandle model);

/// <summary>
/// codellama infill tokens, Beginning of infill suffix
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>Integer reported by the native library (presumably the token id - confirm against llama.h)</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_suffix(SafeLlamaModelHandle model);

/// <summary>
/// codellama infill tokens, End of infill middle
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>Integer reported by the native library (presumably the token id - confirm against llama.h)</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_eot(SafeLlamaModelHandle model);

/// <summary> /// <summary>
/// Print out timing information for this context /// Print out timing information for this context
/// </summary> /// </summary>
@@ -485,13 +519,11 @@ namespace LLama.Native
public static extern void llama_log_set(LLamaLogCallback logCallback); public static extern void llama_log_set(LLamaLogCallback logCallback);


/// <summary> /// <summary>
/// Remove all tokens data of cells in [c0, c1)
/// Clear the KV cache
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="c0"></param>
/// <param name="c1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_tokens_rm(SafeLLamaContextHandle ctx, int c0, int c1);
public static extern void llama_kv_cache_clear(SafeLLamaContextHandle ctx);


/// <summary> /// <summary>
/// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)


+ 9
- 0
LLama/Native/SafeLLamaGrammarHandle.cs View File

@@ -105,6 +105,15 @@ namespace LLama.Native
} }
#endregion #endregion


/// <summary>
/// Duplicate this grammar instance by asking the native library for a copy
/// and wrapping the resulting pointer in a new handle
/// </summary>
/// <returns>A new <see cref="SafeLLamaGrammarHandle"/> wrapping the pointer returned by <c>llama_grammar_copy</c></returns>
public SafeLLamaGrammarHandle Clone()
{
    // The native call hands back a raw pointer; wrap it in a handle before returning it to the caller
    var copiedGrammar = NativeApi.llama_grammar_copy(this);
    var clone = new SafeLLamaGrammarHandle(copiedGrammar);
    return clone;
}

/// <summary> /// <summary>
/// Accepts the sampled token into the grammar /// Accepts the sampled token into the grammar
/// </summary> /// </summary>


Loading…
Cancel
Save