Browse Source

Fixed unit tests

tags/v0.5.1
Martin Evans 2 years ago
parent
commit
9fc17f3136
3 changed files with 26 additions and 10 deletions
  1. +4
    -3
      LLama.Unittest/BasicTest.cs
  2. +7
    -7
      LLama.Unittest/GrammarTest.cs
  3. +15
    -0
      LLama/LLamaWeights.cs

+ 4
- 3
LLama.Unittest/BasicTest.cs View File

@@ -21,10 +21,11 @@ namespace LLama.Unittest
} }


[Fact] [Fact]
public void LoadModel()
public void BasicModelProperties()
{ {
var model = _model.CreateContext(_params, Encoding.UTF8);
model.Dispose();
Assert.Equal(32000, _model.VocabCount);
Assert.Equal(2048, _model.ContextSize);
Assert.Equal(4096, _model.EmbeddingSize);
} }
} }
} }

+ 7
- 7
LLama.Unittest/GrammarTest.cs View File

@@ -40,22 +40,22 @@ namespace LLama.Unittest
[Fact] [Fact]
public void SampleWithTrivialGrammar() public void SampleWithTrivialGrammar()
{ {
// Create a grammar that constrains the output to be "one" and nothing else
// Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
// we can be confident it's not what the LLM would say if not constrained by the grammar!
var rules = new List<List<LLamaGrammarElement>> var rules = new List<List<LLamaGrammarElement>>
{ {
new() new()
{ {
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'o'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'n'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'e'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
}, },
}; };


using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); using var grammar = SafeLLamaGrammarHandle.Create(rules, 0);


using var ctx = _model.CreateContext(_params, Encoding.UTF8);
var executor = new StatelessExecutor(ctx);
var executor = new StatelessExecutor(_model, _params);
var inferenceParams = new InferenceParams var inferenceParams = new InferenceParams
{ {
MaxTokens = 3, MaxTokens = 3,
@@ -65,7 +65,7 @@ namespace LLama.Unittest


var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();


Assert.Equal("one", result[0]);
Assert.Equal("cat", result[0]);
} }
} }
} }

+ 15
- 0
LLama/LLamaWeights.cs View File

@@ -25,6 +25,21 @@ namespace LLama
/// </summary> /// </summary>
public Encoding Encoding { get; } public Encoding Encoding { get; }


/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
public int VocabCount => NativeHandle.VocabCount;

/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize => NativeHandle.ContextSize;

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => NativeHandle.EmbeddingSize;

internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
{ {
_weights = weights; _weights = weights;


Loading…
Cancel
Save