Browse Source

Merge pull request #370 from martindevans/copy_grammar

Clone Grammar
tags/0.9.1
Martin Evans GitHub 1 year ago
parent
commit
cbc4c8d9af
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 5 deletions
  1. +2
    -1
      LLama.Unittest/GrammarTest.cs
  2. +8
    -0
      LLama/Native/NativeApi.Grammar.cs
  3. +36
    -4
      LLama/Native/NativeApi.cs
  4. +9
    -0
      LLama/Native/SafeLLamaGrammarHandle.cs

+ 2
- 1
LLama.Unittest/GrammarTest.cs View File

@@ -74,13 +74,14 @@ namespace LLama.Unittest

var grammar = new Grammar(rules, 0);
using var grammarInstance = grammar.CreateInstance();
using var grammarInstance2 = grammarInstance.Clone();

var executor = new StatelessExecutor(_model, _params);
var inferenceParams = new InferenceParams
{
MaxTokens = 3,
AntiPrompts = new [] { ".", "Input:", "\n" },
Grammar = grammarInstance,
Grammar = grammarInstance2,
};

var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();


+ 8
- 0
LLama/Native/NativeApi.Grammar.cs View File

@@ -24,6 +24,14 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_grammar_free(IntPtr grammar);

/// <summary>
/// Create a copy of an existing grammar instance
/// </summary>
/// <param name="grammar">Handle of the grammar instance to copy</param>
/// <returns>
/// Raw native pointer to the copied grammar.
/// NOTE(review): presumably the caller owns this pointer and must release it (e.g. via llama_grammar_free
/// or by wrapping it in a SafeLLamaGrammarHandle) - confirm against llama.h
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_grammar_copy(SafeLLamaGrammarHandle grammar);

/// <summary>
/// Apply constraints from grammar
/// </summary>


+ 36
- 4
LLama/Native/NativeApi.cs View File

@@ -254,6 +254,12 @@ namespace LLama.Native
}
}

/// <summary>
/// Binding for the native <c>llama_token_get_text</c> function.
/// NOTE(review): presumably returns the text associated with a token in the model vocabulary - confirm against llama.h
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <param name="token">The token to look up</param>
/// <returns>Unmanaged byte pointer; NOTE(review): encoding, termination and ownership are not visible here - confirm before dereferencing</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);

/// <summary>
/// Binding for the native <c>llama_token_get_score</c> function.
/// NOTE(review): presumably returns the vocabulary score of a token - confirm against llama.h
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <param name="token">The token to look up</param>
/// <returns>A float value produced by the native call</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token);

/// <summary>
/// Binding for the native <c>llama_token_get_type</c> function.
/// NOTE(review): presumably returns the kind/category of a token in the model vocabulary - confirm against llama.h
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <param name="token">The token to look up</param>
/// <returns>The <see cref="LLamaTokenType"/> reported by the native call</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, llama_token token);

@@ -330,6 +336,34 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_add_eos_token(SafeLlamaModelHandle model);

/// <summary>
/// codellama infill tokens, Beginning of infill prefix
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>NOTE(review): presumably the id of the infill-prefix token in this model's vocabulary - confirm against llama.h</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_prefix(SafeLlamaModelHandle model);

/// <summary>
/// codellama infill tokens, Beginning of infill middle
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>NOTE(review): presumably the id of the infill-middle token in this model's vocabulary - confirm against llama.h</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_middle(SafeLlamaModelHandle model);

/// <summary>
/// codellama infill tokens, Beginning of infill suffix
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>NOTE(review): presumably the id of the infill-suffix token in this model's vocabulary - confirm against llama.h</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_suffix(SafeLlamaModelHandle model);

/// <summary>
/// codellama infill tokens, End of infill middle
/// </summary>
/// <param name="model">Handle of the model to query</param>
/// <returns>NOTE(review): presumably the id of the end-of-infill (EOT) token in this model's vocabulary - confirm against llama.h</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_eot(SafeLlamaModelHandle model);

/// <summary>
/// Print out timing information for this context
/// </summary>
@@ -485,13 +519,11 @@ namespace LLama.Native
public static extern void llama_log_set(LLamaLogCallback logCallback);

/// <summary>
/// Remove all tokens data of cells in [c0, c1)
/// Clear the KV cache
/// </summary>
/// <param name="ctx"></param>
/// <param name="c0"></param>
/// <param name="c1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_tokens_rm(SafeLLamaContextHandle ctx, int c0, int c1);
public static extern void llama_kv_cache_clear(SafeLLamaContextHandle ctx);

/// <summary>
/// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)


+ 9
- 0
LLama/Native/SafeLLamaGrammarHandle.cs View File

@@ -105,6 +105,15 @@ namespace LLama.Native
}
#endregion

/// <summary>
/// Duplicate this grammar instance
/// </summary>
/// <returns>
/// A new <see cref="SafeLLamaGrammarHandle"/> wrapping the pointer returned by
/// <see cref="NativeApi.llama_grammar_copy"/> for this handle
/// </returns>
public SafeLLamaGrammarHandle Clone()
{
    // Ask the native library for a copy, then wrap the raw pointer in a new handle.
    var copiedGrammar = NativeApi.llama_grammar_copy(this);
    return new SafeLLamaGrammarHandle(copiedGrammar);
}

/// <summary>
/// Accepts the sampled token into the grammar
/// </summary>


Loading…
Cancel
Save