Browse Source

- Created a slightly nicer way to create a grammar (from `IReadOnlyList<IReadOnlyList<LLamaGrammarElement>>`)

- Integrated grammar into sampling
- Added a test for grammar sampling
tags/v0.5.1
Martin Evans 2 years ago
parent
commit
64416ca23c
12 changed files with 221 additions and 13 deletions
  1. +62
    -0
      LLama.Unittest/GrammarTest.cs
  2. +7
    -1
      LLama.Web/Common/ParameterOptions.cs
  3. +6
    -0
      LLama/Abstractions/IInferenceParams.cs
  4. +6
    -0
      LLama/Common/InferenceParams.cs
  5. +15
    -1
      LLama/LLamaContext.cs
  6. +2
    -1
      LLama/LLamaInstructExecutor.cs
  7. +2
    -1
      LLama/LLamaInteractExecutor.cs
  8. +1
    -1
      LLama/LLamaStatelessExecutor.cs
  9. +14
    -3
      LLama/Native/LLamaGrammarElement.cs
  10. +13
    -5
      LLama/Native/NativeApi.Grammar.cs
  11. +81
    -0
      LLama/Native/SafeLLamaGrammarHandle.cs
  12. +12
    -0
      LLama/Native/SamplingApi.cs

+ 62
- 0
LLama.Unittest/GrammarTest.cs View File

@@ -0,0 +1,62 @@
using LLama.Common;
using LLama.Native;

namespace LLama.Unittest
{
public sealed class GrammarTest
    : IDisposable
{
    // Model shared by every test in this class; loaded once here and
    // released again in Dispose.
    private readonly LLamaModel _model = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048));

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void CreateBasicGrammar()
    {
        // One rule which matches a single character in the inclusive range [a-z].
        var singleRule = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };
        var grammarRules = new List<List<LLamaGrammarElement>> { singleRule };

        // Creating the handle without throwing is the assertion here.
        using var handle = SafeLLamaGrammarHandle.Create(grammarRules, 0);
    }

    [Fact]
    public void SampleWithTrivialGrammar()
    {
        // Create a grammar that constrains the output to be "one" and nothing else
        var oneRule = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'o'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'n'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'e'),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };
        var grammarRules = new List<List<LLamaGrammarElement>> { oneRule };

        using var grammar = SafeLLamaGrammarHandle.Create(grammarRules, 0);

        var executor = new StatelessExecutor(_model);
        var inferenceParams = new InferenceParams
        {
            MaxTokens = 3,
            AntiPrompts = new[] { ".", "Input:", "\n" },
            Grammar = grammar,
        };

        var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();

        Assert.Equal("one", result[0]);
    }
}
}

+ 7
- 1
LLama.Web/Common/ParameterOptions.cs View File

@@ -1,5 +1,6 @@
using LLama.Common;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Web.Common
{
@@ -95,5 +96,10 @@ namespace LLama.Web.Common
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
public bool PenalizeNL { get; set; } = true;
}

/// <summary>
/// A grammar to constrain possible tokens. Null (the default) applies no
/// grammar constraint.
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
}
}

+ 6
- 0
LLama/Abstractions/IInferenceParams.cs View File

@@ -1,5 +1,6 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;

namespace LLama.Abstractions
{
@@ -113,5 +114,10 @@ namespace LLama.Abstractions
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
public bool PenalizeNL { get; set; }

/// <summary>
/// An optional grammar to constrain the set of tokens that can be sampled.
/// When null, sampling is unconstrained.
/// </summary>
SafeLLamaGrammarHandle? Grammar { get; set; }
}
}

+ 6
- 0
LLama/Common/InferenceParams.cs View File

@@ -1,6 +1,7 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Common
{
@@ -96,6 +97,11 @@ namespace LLama.Common
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
public bool PenalizeNL { get; set; } = true;

/// <summary>
/// A grammar to constrain the possible tokens sampled. Null (the default)
/// applies no grammar constraint.
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
}

/// <summary>


+ 15
- 1
LLama/LLamaContext.cs View File

@@ -293,11 +293,19 @@ namespace LLama
/// <param name="topP"></param>
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <param name="grammar"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
SafeLLamaGrammarHandle? grammar = null)
{
llama_token id;

if (grammar != null)
{
SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
}

if (temperature <= 0)
{
// Greedy sampling
@@ -331,6 +339,12 @@ namespace LLama
}
mirostat_mu = mu;
}

if (grammar != null)
{
NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
}

return id;
}



+ 2
- 1
LLama/LLamaInstructExecutor.cs View File

@@ -217,7 +217,8 @@ namespace LLama
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
);
MirostatMu = mu;



+ 2
- 1
LLama/LLamaInteractExecutor.cs View File

@@ -206,7 +206,8 @@ namespace LLama
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
);
MirostatMu = mu;



+ 1
- 1
LLama/LLamaStatelessExecutor.cs View File

@@ -72,7 +72,7 @@ namespace LLama
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);

lastTokens.Add(id);



+ 14
- 3
LLama/Native/LLamaGrammarElement.cs View File

@@ -33,14 +33,14 @@ namespace LLama.Native
CHAR_NOT = 4,

/// <summary>
/// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
/// modifies a preceding CHAR or CHAR_ALT to
/// be an inclusive range ([a-z])
/// </summary>
CHAR_RNG_UPPER = 5,

/// <summary>
/// modifies a preceding LLAMA_GRETYPE_CHAR or
/// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
/// modifies a preceding CHAR or
/// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
/// </summary>
CHAR_ALT = 6,
};
@@ -60,5 +60,16 @@ namespace LLama.Native
/// Unicode code point or rule ID
/// </summary>
public uint Value;

/// <summary>
/// Construct a new LLamaGrammarElement
/// </summary>
/// <param name="type">The kind of this element (e.g. CHAR, CHAR_RNG_UPPER, END)</param>
/// <param name="value">Unicode code point or rule ID, depending on <paramref name="type"/></param>
public LLamaGrammarElement(LLamaGrammarElementType type, uint value)
{
    Type = type;
    Value = value;
}
}
}

+ 13
- 5
LLama/Native/NativeApi.Grammar.cs View File

@@ -7,13 +7,21 @@ namespace LLama.Native

public unsafe partial class NativeApi
{
//todo: LLAMA_API struct llama_grammar * llama_grammar_init(const llama_grammar_element** rules, size_t n_rules,size_t start_rule_index);

/// <summary>
/// Free all memory from the given SafeLLamaGrammarHandle
/// Create a new grammar from the given set of grammar rules
/// </summary>
/// <param name="grammar"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
/// <param name="rules"></param>
/// <param name="n_rules"></param>
/// <param name="start_rule_index"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);

/// <summary>
/// Free all memory from the given SafeLLamaGrammarHandle
/// </summary>
/// <param name="grammar"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_grammar_free(IntPtr grammar);

/// <summary>


+ 81
- 0
LLama/Native/SafeLLamaGrammarHandle.cs View File

@@ -1,4 +1,9 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using LLama.Exceptions;

namespace LLama.Native
{
@@ -8,6 +13,11 @@ namespace LLama.Native
public class SafeLLamaGrammarHandle
: SafeLLamaHandleBase
{
#region construction/destruction
/// <summary>
///
/// </summary>
/// <param name="handle"></param>
internal SafeLLamaGrammarHandle(IntPtr handle)
: base(handle)
{
@@ -20,5 +30,76 @@ namespace LLama.Native
SetHandle(IntPtr.Zero);
return true;
}

/// <summary>
/// Create a new llama_grammar
/// </summary>
/// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
/// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
/// <returns>A safe handle owning the newly created native grammar</returns>
/// <exception cref="RuntimeError">Thrown if the native grammar creation fails</exception>
public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
{
    unsafe
    {
        var totalElements = rules.Sum(a => a.Count);
        var nrules = (ulong)rules.Count;

        // Borrow an array large enough to hold every single element
        // and another array large enough to hold a pointer to each rule.
        // Rented arrays may be longer than requested; only the first
        // totalElements / rules.Count slots are filled and used.
        var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
        var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
        try
        {
            // Pin the flat element array so the raw pointers taken below
            // remain valid for the duration of the native call
            fixed (LLamaGrammarElement* allElementsPtr = allElements)
            {
                var elementIndex = 0;
                var pointerIndex = 0;
                foreach (var rule in rules)
                {
                    // Save a pointer to the start of this rule
                    pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);

                    // Copy all of the rule elements into the flat array
                    foreach (var element in rule)
                        allElementsPtr[elementIndex++] = element;
                }

                // Sanity check some things that should be true if the copy worked as planned
                Debug.Assert((ulong)pointerIndex == nrules);
                Debug.Assert(elementIndex == totalElements);

                // Make the actual call through to llama.cpp.
                // NOTE(review): the rented buffers are returned to the pool as soon
                // as this method exits — this assumes llama_grammar_init copies the
                // rule data rather than retaining these pointers; confirm against
                // the llama.cpp implementation.
                fixed (void* ptr = pointers)
                {
                    return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                }
            }
        }
        finally
        {
            ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
            ArrayPool<IntPtr>.Shared.Return(pointers);
        }
    }
}

/// <summary>
/// Create a new llama_grammar
/// </summary>
/// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
/// <param name="nrules">total number of rules</param>
/// <param name="start_rule_index">index of the start rule of the grammar</param>
/// <returns>A safe handle wrapping the newly created native grammar</returns>
/// <exception cref="RuntimeError">Thrown when the native call returns a null pointer</exception>
public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
{
    var native = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);

    return native != IntPtr.Zero
        ? new SafeLLamaGrammarHandle(native)
        : throw new RuntimeError("Failed to create grammar from rules");
}
#endregion
}
}

+ 12
- 0
LLama/Native/SamplingApi.cs View File

@@ -5,6 +5,18 @@ namespace LLama.Native
using llama_token = Int32;
public unsafe class SamplingApi
{
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx">Context handle the candidates were produced from</param>
/// <param name="candidates">Candidate tokens to be filtered/re-weighted by the grammar</param>
/// <param name="grammar">Grammar to apply</param>
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
    // Build a native view over the candidates (presumably pinning the
    // underlying array — see LLamaTokenDataArrayNative.Create).
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
    NativeApi.llama_sample_grammar(ctx, ref st, grammar);
    // NOTE(review): any changes the native call makes to st.size / st.sorted
    // are not visibly copied back into `candidates` here — confirm whether
    // the handle propagates them on dispose.
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>


Loading…
Cancel
Save