From 9fc17f3136aed9a322d10a1ee20a2202af74c4f9 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 22 Aug 2023 14:16:20 +0100 Subject: [PATCH] Fixed unit tests --- LLama.Unittest/BasicTest.cs | 7 ++++--- LLama.Unittest/GrammarTest.cs | 14 +++++++------- LLama/LLamaWeights.cs | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index f3fed804..6ad2e8a9 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -21,10 +21,11 @@ namespace LLama.Unittest } [Fact] - public void LoadModel() + public void BasicModelProperties() { - var model = _model.CreateContext(_params, Encoding.UTF8); - model.Dispose(); + Assert.Equal(32000, _model.VocabCount); + Assert.Equal(2048, _model.ContextSize); + Assert.Equal(4096, _model.EmbeddingSize); } } } \ No newline at end of file diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs index c2b60a1a..081fe99c 100644 --- a/LLama.Unittest/GrammarTest.cs +++ b/LLama.Unittest/GrammarTest.cs @@ -40,22 +40,22 @@ namespace LLama.Unittest [Fact] public void SampleWithTrivialGrammar() { - // Create a grammar that constrains the output to be "one" and nothing else + // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so + // we can be confident it's not what the LLM would say if not constrained by the grammar! var rules = new List> { new() { - new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'o'), - new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'n'), - new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'e'), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), new LLamaGrammarElement(LLamaGrammarElementType.END, 0), }, }; using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); - using var ctx = _model.CreateContext(_params, Encoding.UTF8); - var executor = new StatelessExecutor(ctx); + var executor = new StatelessExecutor(_model, _params); var inferenceParams = new InferenceParams { MaxTokens = 3, @@ -65,7 +65,7 @@ namespace LLama.Unittest var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); - Assert.Equal("one", result[0]); + Assert.Equal("cat", result[0]); } } } diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 7d0ba1b0..8997e9c4 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -25,6 +25,21 @@ namespace LLama /// public Encoding Encoding { get; } + /// + /// Total number of tokens in vocabulary of this model + /// + public int VocabCount => NativeHandle.VocabCount; + + /// + /// Total number of tokens in the context + /// + public int ContextSize => NativeHandle.ContextSize; + + /// + /// Dimension of embedding vectors + /// + public int EmbeddingSize => NativeHandle.EmbeddingSize; + internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) { _weights = weights;