diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml new file mode 100644 index 00000000..c3bcbe2a --- /dev/null +++ b/.github/workflows/compile.yml @@ -0,0 +1,240 @@ +name: Update Binaries + +on: + workflow_dispatch: + inputs: + cublas: + type: boolean + description: Build CUBLAS binaries + macos: + type: boolean + description: Build MacOS binaries + push: + branches: [cron_job] + #schedule: + # - cron: "22 22 * * 2" + +jobs: + compile-linux: + name: Compile (Linux) + strategy: + fail-fast: true + matrix: + include: + - build: 'noavx' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON' + - build: 'avx2' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON' + - build: 'avx' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DBUILD_SHARED_LIBS=ON' + - build: 'avx512' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + repository: ggerganov/llama.cpp + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake .. ${{ matrix.defines }} + cmake --build . --config Release + - uses: actions/upload-artifact@v3 + with: + path: ./build/libllama.so + name: llama-bin-linux-${{ matrix.build }}-x64.so + + compile-windows: + name: Compile (Windows) + strategy: + fail-fast: true + matrix: + include: + - build: 'noavx' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON' + - build: 'avx2' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON' + - build: 'avx' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DBUILD_SHARED_LIBS=ON' + - build: 'avx512' + defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON' + runs-on: windows-latest + steps: + - uses: actions/checkout@v3 + with: + repository: ggerganov/llama.cpp + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake .. ${{ matrix.defines }} + cmake --build . --config Release + + - name: Upload artifacts + uses: actions/upload-artifact@v3 + with: + path: .\build\bin\Release\llama.dll + name: llama-bin-win-${{ matrix.build }}-x64.dll + + compile-cublas: + if: ${{ github.event.inputs.cublas }} + name: Compile (cublas) + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest] + cuda: ['12.1.0', '11.7.1'] + runs-on: ${{ matrix.os }} + steps: + - name: Clone + id: checkout + uses: actions/checkout@v3 + with: + repository: ggerganov/llama.cpp + + - uses: Jimver/cuda-toolkit@v0.2.10 + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda }} + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake .. -DLLAMA_CUBLAS=ON -DBUILD_SHARED_LIBS=ON -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF + cmake --build . 
--config Release + ls -R + + - name: Upload artifacts (Windows) + if: ${{ matrix.os == 'windows-latest' }} + uses: actions/upload-artifact@v3 + with: + path: .\build\bin\Release\llama.dll + name: llama-bin-win-cublas-cu${{ matrix.cuda }}-x64.dll + - name: Upload artifacts (Linux) + if: ${{ matrix.os == 'ubuntu-latest' }} + uses: actions/upload-artifact@v3 + with: + path: ./build/libllama.so + name: llama-bin-linux-cublas-cu${{ matrix.cuda }}-x64.so + + compile-macos: + if: ${{ github.event.inputs.macos }} + name: Compile (MacOS) + runs-on: macos-latest + strategy: + fail-fast: true + matrix: + arch: [ + "arm64" + ] + + steps: + - uses: actions/checkout@v3 + with: + repository: ggerganov/llama.cpp + + - name: Dependencies + continue-on-error: true + run: | + brew update + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_OSX_ARCHITECTURES=${{ matrix.arch }} .. + cmake --build . --config Release + + - name: Upload artifacts + uses: actions/upload-artifact@v3 + with: + path: ./build/libllama.dylib + name: llama-bin-macos-${{ matrix.arch }}.dylib + + build-deps: + runs-on: ubuntu-latest + name: "Gather Binaries" + if: ${{ always() }} + needs: [ + "compile-linux", + "compile-macos", + "compile-windows", + "compile-cublas" + ] + steps: + - uses: actions/download-artifact@v3 + with: + path: artifacts + + - name: Rearrange Files + run: | + ls -R + + mkdir deps + + mkdir deps/linux + mkdir deps/linux/noavx + cp artifacts/llama-bin-linux-noavx-x64.so/libllama.so deps/linux/noavx/libllama.so + mkdir deps/linux/avx + cp artifacts/llama-bin-linux-avx-x64.so/libllama.so deps/linux/avx/libllama.so + mkdir deps/linux/avx2 + cp artifacts/llama-bin-linux-avx2-x64.so/libllama.so deps/linux/avx2/libllama.so + mkdir deps/linux/avx512 + cp artifacts/llama-bin-linux-avx512-x64.so/libllama.so deps/linux/avx512/libllama.so + + mkdir deps/win + mkdir deps/win/noavx + cp artifacts/llama-bin-win-noavx-x64.dll/llama.dll deps/win/noavx/libllama.dll + mkdir deps/win/avx + cp artifacts/llama-bin-win-avx-x64.dll/llama.dll deps/win/avx/libllama.dll + mkdir deps/win/avx2 + cp artifacts/llama-bin-win-avx2-x64.dll/llama.dll deps/win/avx2/libllama.dll + mkdir deps/win/avx512 + cp artifacts/llama-bin-win-avx512-x64.dll/llama.dll deps/win/avx512/libllama.dll + + - name: Rearrange MacOS files + if: ${{ github.event.inputs.macos }} + run: | + mkdir deps/macos-arm64 + cp artifacts/llama-bin-macos-arm64.dylib/libllama.dylib deps/macos-arm64/libllama.dylib + + - name: Rearrange CUDA files + if: ${{ github.event.inputs.cublas }} + run: | + mkdir cu11.7.1 + cp artifacts/llama-bin-win-cublas-cu11.7.1-x64.dll/llama.dll cu11.7.1/libllama.dll + cp artifacts/llama-bin-linux-cublas-cu11.7.1-x64.so/libllama.so cu11.7.1/libllama.so + mkdir cu12.1.0 + cp artifacts/llama-bin-win-cublas-cu12.1.0-x64.dll/llama.dll cu12.1.0/libllama.dll + cp artifacts/llama-bin-linux-cublas-cu12.1.0-x64.so/libllama.so cu12.1.0/libllama.so + + - name: Upload artifacts + uses: actions/upload-artifact@v3 + with: + path: deps/ + name: deps + - name: Upload artifacts (CUDA12) + if: ${{ github.event.inputs.cublas }} + uses: actions/upload-artifact@v3 + with: + path: cu12.1.0/ + name: cu12.1.0 + - name: Upload artifacts (CUDA11) + if: ${{ github.event.inputs.cublas }} + uses: actions/upload-artifact@v3 + with: + path: cu11.7.1/ + name: cu11.7.1 + + - name: Remove Artifacts + uses: geekyeggo/delete-artifact@v2 
+ with: + name: | + llama-* diff --git a/.gitignore b/.gitignore index e7c87968..f7b8be30 100644 --- a/.gitignore +++ b/.gitignore @@ -344,4 +344,5 @@ test/TensorFlowNET.Examples/mnist site/ /LLama.Unittest/Models/*.bin +/LLama.Unittest/Models/*.gguf diff --git a/LLama.Examples/Assets/json.gbnf b/LLama.Examples/Assets/json.gbnf new file mode 100644 index 00000000..a01c4efd --- /dev/null +++ b/LLama.Examples/Assets/json.gbnf @@ -0,0 +1,27 @@ +# https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/grammars/json.gbnf + +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? \ No newline at end of file diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index ef7ac437..a8abe3ae 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -27,6 +27,11 @@ + + + + + @@ -49,6 +54,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs index ce677c40..65ac8d91 100644 --- a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,15 +7,27 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. 
The role names won't be printed."); Console.ForegroundColor = ConsoleColor.White; + // show the prompt + Console.Write(prompt); while (true) { foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) diff --git a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs index cbf9333b..dcbcc07b 100644 --- a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +7,20 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var session = new ChatSession(executor); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); diff --git a/LLama.Examples/NewVersion/GetEmbeddings.cs b/LLama.Examples/NewVersion/GetEmbeddings.cs index ed12f868..516d2da7 100644 --- a/LLama.Examples/NewVersion/GetEmbeddings.cs +++ b/LLama.Examples/NewVersion/GetEmbeddings.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,7 +7,7 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var embedder = new LLamaEmbedder(new ModelParams(modelPath)); while (true) diff --git a/LLama.Examples/NewVersion/GrammarJsonResponse.cs b/LLama.Examples/NewVersion/GrammarJsonResponse.cs new file mode 100644 index 00000000..a3c147f5 --- /dev/null +++ b/LLama.Examples/NewVersion/GrammarJsonResponse.cs @@ -0,0 +1,53 @@ +using LLama.Common; +using LLama.Grammars; + +namespace LLama.Examples.NewVersion +{ + public class GrammarJsonResponse + { + public static void Run() + { + var gbnf = File.ReadAllText("Assets/json.gbnf").Trim(); + var grammar = Grammar.Parse(gbnf, "root"); + + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + var ex = new StatelessExecutor(model, parameters); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. 
For example, you can input \"Tell me the attributes of a good dish\""); + Console.ForegroundColor = ConsoleColor.White; + + using var grammarInstance = grammar.CreateInstance(); + var inferenceParams = new InferenceParams() + { + Temperature = 0.6f, + AntiPrompts = new List { "Question:", "#", "Question: ", ".\n" }, + MaxTokens = 50, + Grammar = grammarInstance + }; + + while (true) + { + Console.Write("\nQuestion: "); + Console.ForegroundColor = ConsoleColor.Green; + var prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.White; + Console.Write("Answer: "); + prompt = $"Question: {prompt?.Trim()} Answer: "; + foreach (var text in ex.Infer(prompt, inferenceParams)) + { + Console.Write(text); + } + } + } + } +} diff --git a/LLama.Examples/NewVersion/InstructModeExecute.cs b/LLama.Examples/NewVersion/InstructModeExecute.cs index 303c8644..b0e325f1 100644 --- a/LLama.Examples/NewVersion/InstructModeExecute.cs +++ b/LLama.Examples/NewVersion/InstructModeExecute.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +7,18 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/dan.txt").Trim(); - InstructExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024))); + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InstructExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. 
For example, you can input \"Write a story about a fox who want to " + @@ -26,7 +29,7 @@ namespace LLama.Examples.NewVersion while (true) { - foreach (var text in ex.Infer(prompt, inferenceParams)) + foreach (var text in executor.Infer(prompt, inferenceParams)) { Console.Write(text); } diff --git a/LLama.Examples/NewVersion/InteractiveModeExecute.cs b/LLama.Examples/NewVersion/InteractiveModeExecute.cs index 23afcadf..8bc002eb 100644 --- a/LLama.Examples/NewVersion/InteractiveModeExecute.cs +++ b/LLama.Examples/NewVersion/InteractiveModeExecute.cs @@ -1,21 +1,24 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { public class InteractiveModeExecute { - public async static Task Run() + public static async Task Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + var modelPath = Console.ReadLine(); + var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)"); diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index 722ec3e0..948ac6cd 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -1,10 +1,4 @@ using LLama.Common; -using LLama.OldVersion; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -13,10 +7,20 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var ex = new InteractiveExecutor(context); + + var session = new ChatSession(ex); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result. 
Input \"save\" to save and reload the session."); @@ -45,8 +49,8 @@ namespace LLama.Examples.NewVersion Console.WriteLine("Saved session!"); Console.ForegroundColor = ConsoleColor.White; - ex.Model.Dispose(); - ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + ex.Context.Dispose(); + ex = new(new LLamaContext(parameters)); session = new ChatSession(ex); session.LoadSession(statePath); diff --git a/LLama.Examples/NewVersion/LoadAndSaveState.cs b/LLama.Examples/NewVersion/LoadAndSaveState.cs index dc303141..e7e0d4ef 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveState.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveState.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +7,18 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. (an example for small scale usage)"); @@ -36,20 +39,20 @@ namespace LLama.Examples.NewVersion if (prompt == "save") { Console.Write("Your path to save model state: "); - string modelStatePath = Console.ReadLine(); - ex.Model.SaveState(modelStatePath); + var modelStatePath = Console.ReadLine(); + ex.Context.SaveState(modelStatePath); Console.Write("Your path to save executor state: "); - string executorStatePath = Console.ReadLine(); + var executorStatePath = Console.ReadLine(); ex.SaveState(executorStatePath); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("All states saved!"); Console.ForegroundColor = ConsoleColor.White; - var model = ex.Model; - model.LoadState(modelStatePath); - ex = new InteractiveExecutor(model); + var ctx = ex.Context; + ctx.LoadState(modelStatePath); + ex = new InteractiveExecutor(ctx); ex.LoadState(executorStatePath); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Loaded state!"); diff --git a/LLama.Examples/NewVersion/QuantizeModel.cs b/LLama.Examples/NewVersion/QuantizeModel.cs index a5ad81d8..71966af8 100644 --- a/LLama.Examples/NewVersion/QuantizeModel.cs +++ b/LLama.Examples/NewVersion/QuantizeModel.cs @@ -1,11 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace LLama.Examples.NewVersion +namespace LLama.Examples.NewVersion { public class QuantizeModel { @@ -13,13 +6,16 @@ namespace LLama.Examples.NewVersion { Console.Write("Please input your original model path: "); var inputPath = Console.ReadLine(); + Console.Write("Please input your output model path: "); var outputPath = Console.ReadLine(); + Console.Write("Please input the quantize type (one of q4_0, q4_1, q5_0, q5_1, q8_0): "); var quantizeType = Console.ReadLine(); 
+ if (LLamaQuantizer.Quantize(inputPath, outputPath, quantizeType)) { - Console.WriteLine("Quantization succeed!"); + Console.WriteLine("Quantization succeeded!"); } else { diff --git a/LLama.Examples/NewVersion/SemanticKernelChat.cs b/LLama.Examples/NewVersion/SemanticKernelChat.cs new file mode 100644 index 00000000..9bdbcfec --- /dev/null +++ b/LLama.Examples/NewVersion/SemanticKernelChat.cs @@ -0,0 +1,69 @@ +using System.Reflection.Metadata; +using System.Security.Cryptography; +using System.Text; +using LLama.Abstractions; +using LLama.Common; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.AI.ChatCompletion; +using Microsoft.SemanticKernel.AI.TextCompletion; +using LLamaSharp.SemanticKernel.ChatCompletion; +using LLamaSharp.SemanticKernel.TextCompletion; + +namespace LLama.Examples.NewVersion +{ + public class SemanticKernelChat + { + public static async Task Run() + { + Console.WriteLine("Example from: https://github.com/microsoft/semantic-kernel/blob/main/dotnet/README.md"); + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + // Load weights into memory + var parameters = new ModelParams(modelPath) + { + Seed = RandomNumberGenerator.GetInt32(int.MaxValue), + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var ex = new InteractiveExecutor(context); + + var chatGPT = new LLamaSharpChatCompletion(ex); + + var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books"); + + Console.WriteLine("Chat content:"); + Console.WriteLine("------------------------"); + + chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); + await MessageOutputAsync(chatHistory); + + // First bot assistant message + string reply = await chatGPT.GenerateMessageAsync(chatHistory); + chatHistory.AddAssistantMessage(reply); + await MessageOutputAsync(chatHistory); + + // Second user message + chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); + await MessageOutputAsync(chatHistory); + + // Second bot assistant message + reply = await chatGPT.GenerateMessageAsync(chatHistory); + chatHistory.AddAssistantMessage(reply); + await MessageOutputAsync(chatHistory); + } + + /// + /// Outputs the last message of the chat history + /// + private static Task MessageOutputAsync(Microsoft.SemanticKernel.AI.ChatCompletion.ChatHistory chatHistory) + { + var message = chatHistory.Messages.Last(); + + Console.WriteLine($"{message.Role}: {message.Content}"); + Console.WriteLine("------------------------"); + + return Task.CompletedTask; + } + } +} diff --git a/LLama.Examples/NewVersion/SemanticKernelPrompt.cs b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs new file mode 100644 index 00000000..0482c195 --- /dev/null +++ b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs @@ -0,0 +1,55 @@ +using System.Reflection.Metadata; +using System.Security.Cryptography; +using System.Text; +using LLama.Abstractions; +using LLama.Common; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.AI.ChatCompletion; +using Microsoft.SemanticKernel.AI.TextCompletion; +using LLamaSharp.SemanticKernel.TextCompletion; + +namespace LLama.Examples.NewVersion +{ + public class SemanticKernelPrompt + { + public static async Task Run() + { + Console.WriteLine("Example from: https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/KernelSyntaxExamples/Example17_ChatGPT.cs"); + Console.Write("Please input 
your model path: "); + var modelPath = Console.ReadLine(); + + // Load weights into memory + var parameters = new ModelParams(modelPath) + { + Seed = RandomNumberGenerator.GetInt32(int.MaxValue), + }; + using var model = LLamaWeights.LoadFromFile(parameters); + var ex = new StatelessExecutor(model, parameters); + + var builder = new KernelBuilder(); + builder.WithAIService("local-llama", new LLamaSharpTextCompletion(ex), true); + + var kernel = builder.Build(); + + var prompt = @"{{$input}} + +One line TLDR with the fewest words."; + + var summarize = kernel.CreateSemanticFunction(prompt, maxTokens: 100); + + string text1 = @" +1st Law of Thermodynamics - Energy cannot be created or destroyed. +2nd Law of Thermodynamics - For a spontaneous process, the entropy of the universe increases. +3rd Law of Thermodynamics - A perfect crystal at zero Kelvin has zero entropy."; + + string text2 = @" +1. An object at rest remains at rest, and an object in motion remains in motion at constant speed and in a straight line unless acted on by an unbalanced force. +2. The acceleration of an object depends on the mass of the object and the amount of force applied. +3. Whenever one object exerts a force on another object, the second object exerts an equal and opposite on the first."; + + Console.WriteLine(await summarize.InvokeAsync(text1)); + + Console.WriteLine(await summarize.InvokeAsync(text2)); + } + } +} diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs index 3c485231..ddd6227f 100644 --- a/LLama.Examples/NewVersion/StatelessModeExecute.cs +++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,9 +7,16 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); - StatelessExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + var ex = new StatelessExecutor(model, parameters); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. 
That says, the previous input and response has " + @@ -29,10 +31,10 @@ namespace LLama.Examples.NewVersion { Console.Write("\nQuestion: "); Console.ForegroundColor = ConsoleColor.Green; - string prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; + var prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.White; Console.Write("Answer: "); - prompt = $"Question: {prompt.Trim()} Answer: "; + prompt = $"Question: {prompt?.Trim()} Answer: "; foreach (var text in ex.Infer(prompt, inferenceParams)) { Console.Write(text); diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs new file mode 100644 index 00000000..309d5654 --- /dev/null +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -0,0 +1,74 @@ +using System.Security.Cryptography; +using System.Text; +using LLama.Abstractions; +using LLama.Common; + +namespace LLama.Examples.NewVersion +{ + public class TalkToYourself + { + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + // Load weights into memory + var @params = new ModelParams(modelPath) + { + Seed = RandomNumberGenerator.GetInt32(int.MaxValue) + }; + using var weights = LLamaWeights.LoadFromFile(@params); + + // Create 2 contexts sharing the same weights + using var aliceCtx = weights.CreateContext(@params); + var alice = new InteractiveExecutor(aliceCtx); + using var bobCtx = weights.CreateContext(@params); + var bob = new InteractiveExecutor(bobCtx); + + // Initial alice prompt + var alicePrompt = "Transcript of a dialog, where the Alice interacts a person named Bob. Alice is friendly, kind, honest and good at writing.\nAlice: Hello"; + var aliceResponse = await Prompt(alice, ConsoleColor.Green, alicePrompt, false, false); + + // Initial bob prompt + var bobPrompt = $"Transcript of a dialog, where the Bob interacts a person named Alice. 
Bob is smart, intellectual and good at writing.\nAlice: Hello{aliceResponse}"; + var bobResponse = await Prompt(bob, ConsoleColor.Red, bobPrompt, true, true); + + // swap back and forth from Alice to Bob + while (true) + { + aliceResponse = await Prompt(alice, ConsoleColor.Green, bobResponse, false, true); + bobResponse = await Prompt(bob, ConsoleColor.Red, aliceResponse, false, true); + + if (Console.KeyAvailable) + break; + } + } + + private static async Task Prompt(ILLamaExecutor executor, ConsoleColor color, string prompt, bool showPrompt, bool showResponse) + { + var inferenceParams = new InferenceParams + { + Temperature = 0.9f, + AntiPrompts = new List { "Alice:", "Bob:", "User:" }, + MaxTokens = 128, + Mirostat = MirostatType.Mirostat2, + MirostatTau = 10, + }; + + Console.ForegroundColor = ConsoleColor.White; + if (showPrompt) + Console.Write(prompt); + + Console.ForegroundColor = color; + var builder = new StringBuilder(); + await foreach (var text in executor.InferAsync(prompt, inferenceParams)) + { + builder.Append(text); + if (showResponse) + Console.Write(text); + } + + return builder.ToString(); + } + } +} diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index 23c9ae6b..83316510 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -1,10 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace LLama.Examples.NewVersion +namespace LLama.Examples.NewVersion { public class NewVersionTestRunner { @@ -14,7 +8,7 @@ namespace LLama.Examples.NewVersion Console.WriteLine("Please input a number to choose an example to run:"); Console.WriteLine("0: Run a chat session without stripping the role names."); - Console.WriteLine("1: Run a chat session with the role names strippped."); + Console.WriteLine("1: Run a chat session with the role names stripped."); Console.WriteLine("2: Interactive mode chat by using executor."); Console.WriteLine("3: Instruct mode chat by using executor."); Console.WriteLine("4: Stateless mode chat by using executor."); @@ -22,6 +16,10 @@ namespace LLama.Examples.NewVersion Console.WriteLine("6: Load and save state of model and executor."); Console.WriteLine("7: Get embeddings from LLama model."); Console.WriteLine("8: Quantize the model."); + Console.WriteLine("9: Automatic conversation."); + Console.WriteLine("10: Constrain response to json format using grammar."); + Console.WriteLine("11: Semantic Kernel Prompt."); + Console.WriteLine("12: Semantic Kernel Chat."); while (true) { @@ -64,6 +62,22 @@ namespace LLama.Examples.NewVersion { QuantizeModel.Run(); } + else if (choice == 9) + { + await TalkToYourself.Run(); + } + else if (choice == 10) + { + GrammarJsonResponse.Run(); + } + else if (choice == 11) + { + await SemanticKernelPrompt.Run(); + } + else if (choice == 12) + { + await SemanticKernelChat.Run(); + } else { Console.WriteLine("Cannot parse your choice. 
Please select again."); diff --git a/LLama.Examples/OldVersion/ChatSession.cs b/LLama.Examples/OldVersion/ChatSession.cs index 52216803..3f851532 100644 --- a/LLama.Examples/OldVersion/ChatSession.cs +++ b/LLama.Examples/OldVersion/ChatSession.cs @@ -7,6 +7,7 @@ using LLama.OldVersion; namespace LLama.Examples.Old { + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public class ChatSession { LLama.OldVersion.ChatSession _session; diff --git a/LLama.Examples/OldVersion/ChatWithLLamaModel.cs b/LLama.Examples/OldVersion/ChatWithLLamaModel.cs index 452b5b2d..88adebc7 100644 --- a/LLama.Examples/OldVersion/ChatWithLLamaModel.cs +++ b/LLama.Examples/OldVersion/ChatWithLLamaModel.cs @@ -7,6 +7,7 @@ using LLama.OldVersion; namespace LLama.Examples.Old { + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public class ChatWithLLamaModel { LLama.OldVersion.LLamaModel _model; diff --git a/LLama.Examples/OldVersion/GetEmbeddings.cs b/LLama.Examples/OldVersion/GetEmbeddings.cs index df620ea1..8dd28109 100644 --- a/LLama.Examples/OldVersion/GetEmbeddings.cs +++ b/LLama.Examples/OldVersion/GetEmbeddings.cs @@ -7,6 +7,7 @@ using LLama.OldVersion; namespace LLama.Examples.Old { + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public class GetEmbeddings { LLama.OldVersion.LLamaEmbedder _embedder; diff --git a/LLama.Examples/OldVersion/InstructMode.cs b/LLama.Examples/OldVersion/InstructMode.cs index 2b954e3f..ce123366 100644 --- a/LLama.Examples/OldVersion/InstructMode.cs +++ b/LLama.Examples/OldVersion/InstructMode.cs @@ -7,6 +7,7 @@ using LLama.OldVersion; namespace LLama.Examples.Old { + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public class InstructMode { LLama.OldVersion.LLamaModel _model; diff --git a/LLama.Examples/OldVersion/SaveAndLoadState.cs b/LLama.Examples/OldVersion/SaveAndLoadState.cs index bcb77409..abe6bbfb 100644 --- a/LLama.Examples/OldVersion/SaveAndLoadState.cs +++ b/LLama.Examples/OldVersion/SaveAndLoadState.cs @@ -7,6 +7,7 @@ using LLama.OldVersion; namespace LLama.Examples.Old { + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public class SaveAndLoadState: IDisposable { LLama.OldVersion.LLamaModel _model; diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs index 9e5baaef..1699fa66 100644 --- a/LLama.Examples/Program.cs +++ b/LLama.Examples/Program.cs @@ -1,7 +1,4 @@ -using LLama; -using LLama.Common; -using LLama.Examples; -using LLama.Examples.NewVersion; +using LLama.Examples.NewVersion; using LLama.Examples.Old; Console.WriteLine("======================================================================================================"); diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs new file mode 100644 index 00000000..759888d0 --- /dev/null +++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs @@ -0,0 +1,17 @@ +using static LLama.LLamaTransforms; + +namespace LLamaSharp.SemanticKernel.ChatCompletion; + +/// +/// Default HistoryTransform Patch +/// +public class HistoryTransform : DefaultHistoryTransform +{ + /// + public override string HistoryToText(global::LLama.Common.ChatHistory history) + { + var prompt = base.HistoryToText(history); + return prompt + "\nAssistant:"; + + } +} diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs new file mode 100644 index 00000000..7fda3d4f --- 
/dev/null +++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs @@ -0,0 +1,74 @@ +using LLama; +using Microsoft.SemanticKernel.AI.ChatCompletion; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace LLamaSharp.SemanticKernel.ChatCompletion; + +/// +/// LLamaSharp ChatCompletion +/// +public sealed class LLamaSharpChatCompletion : IChatCompletion +{ + private const string UserRole = "user:"; + private const string AssistantRole = "assistant:"; + private ChatSession session; + + public LLamaSharpChatCompletion(InteractiveExecutor model) + { + this.session = new ChatSession(model) + .WithHistoryTransform(new HistoryTransform()) + .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole })); + } + + /// + public ChatHistory CreateNewChat(string? instructions = "") + { + var history = new ChatHistory(); + + if (instructions != null && !string.IsNullOrEmpty(instructions)) + { + history.AddSystemMessage(instructions); + } + + return history; + } + + /// + public async Task> GetChatCompletionsAsync(ChatHistory chat, ChatRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) + { + requestSettings ??= new ChatRequestSettings() + { + MaxTokens = 256, + Temperature = 0, + TopP = 0, + StopSequences = new List { } + }; + + var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); + + return new List { new LLamaSharpChatResult(result) }.AsReadOnly(); + } + + /// + public async IAsyncEnumerable GetStreamingChatCompletionsAsync(ChatHistory chat, ChatRequestSettings? 
requestSettings = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + requestSettings ??= new ChatRequestSettings() + { + MaxTokens = 256, + Temperature = 0, + TopP = 0, + StopSequences = new List { } + }; + + var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); + + yield return new LLamaSharpChatResult(result); + } +} diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatMessage.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatMessage.cs new file mode 100644 index 00000000..1e54d0a1 --- /dev/null +++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatMessage.cs @@ -0,0 +1,14 @@ +using Microsoft.SemanticKernel.AI.ChatCompletion; + +namespace LLamaSharp.SemanticKernel.ChatCompletion; + +/// +/// LLamaSharp Chat Message +/// +public class LLamaSharpChatMessage : ChatMessageBase +{ + /// + public LLamaSharpChatMessage(AuthorRole role, string content) : base(role, content) + { + } +} diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatResult.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatResult.cs new file mode 100644 index 00000000..ec479f42 --- /dev/null +++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatResult.cs @@ -0,0 +1,38 @@ +using Microsoft.SemanticKernel.AI.ChatCompletion; +using System.Runtime.CompilerServices; +using System.Text; + +namespace LLamaSharp.SemanticKernel.ChatCompletion; + +internal sealed class LLamaSharpChatResult : IChatStreamingResult +{ + private readonly IAsyncEnumerable _stream; + + /// + /// + /// + /// + public LLamaSharpChatResult(IAsyncEnumerable stream) + { + _stream = stream; + } + /// + public async Task GetChatMessageAsync(CancellationToken cancellationToken = default) + { + var sb = new StringBuilder(); + await foreach (var token in _stream) + { + sb.Append(token); + } + return await Task.FromResult(new LLamaSharpChatMessage(AuthorRole.Assistant, sb.ToString())).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable GetStreamingChatMessageAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var token in _stream) + { + yield return new LLamaSharpChatMessage(AuthorRole.Assistant, token); + } + } +} diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs new file mode 100644 index 00000000..90090ead --- /dev/null +++ b/LLama.SemanticKernel/ExtensionMethods.cs @@ -0,0 +1,72 @@ +using Microsoft.SemanticKernel.AI.ChatCompletion; +using Microsoft.SemanticKernel.AI.TextCompletion; + +namespace LLamaSharp.SemanticKernel; + +internal static class ExtensionMethods +{ + internal static global::LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory chatHistory) + { + if (chatHistory is null) + { + throw new ArgumentNullException(nameof(chatHistory)); + } + + var history = new global::LLama.Common.ChatHistory(); + + foreach (var chat in chatHistory) + { + var role = Enum.TryParse(chat.Role.Label, out var _role) ? 
_role : global::LLama.Common.AuthorRole.Unknown; + history.AddMessage(role, chat.Content); + } + + return history; + } + + /// + /// Convert ChatRequestSettings to LLamaSharp InferenceParams + /// + /// + /// + internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this ChatRequestSettings requestSettings) + { + if (requestSettings is null) + { + throw new ArgumentNullException(nameof(requestSettings)); + } + + var antiPrompts = new List(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" }; + return new global::LLama.Common.InferenceParams + { + Temperature = (float)requestSettings.Temperature, + TopP = (float)requestSettings.TopP, + PresencePenalty = (float)requestSettings.PresencePenalty, + FrequencyPenalty = (float)requestSettings.FrequencyPenalty, + AntiPrompts = antiPrompts, + MaxTokens = requestSettings.MaxTokens ?? -1 + }; + } + + /// + /// Convert CompleteRequestSettings to LLamaSharp InferenceParams + /// + /// + /// + internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this CompleteRequestSettings requestSettings) + { + if (requestSettings is null) + { + throw new ArgumentNullException(nameof(requestSettings)); + } + + return new global::LLama.Common.InferenceParams + { + Temperature = (float)requestSettings.Temperature, + TopP = (float)requestSettings.TopP, + PresencePenalty = (float)requestSettings.PresencePenalty, + FrequencyPenalty = (float)requestSettings.FrequencyPenalty, + AntiPrompts = requestSettings.StopSequences, + MaxTokens = requestSettings.MaxTokens ?? -1 + }; + } +} diff --git a/LLama.SemanticKernel/LLamaSharp.SemanticKernel.csproj b/LLama.SemanticKernel/LLamaSharp.SemanticKernel.csproj new file mode 100644 index 00000000..fc5af9b1 --- /dev/null +++ b/LLama.SemanticKernel/LLamaSharp.SemanticKernel.csproj @@ -0,0 +1,22 @@ + + + + netstandard2.0;net6.0;net7.0 + LLamaSharp.SemanticKernel + enable + 10 + AnyCPU;x64;Arm64 + True + enable + enable + + + + + + + + + + + diff --git a/LLama.SemanticKernel/README.md b/LLama.SemanticKernel/README.md new file mode 100644 index 00000000..369968b0 --- /dev/null +++ b/LLama.SemanticKernel/README.md @@ -0,0 +1,26 @@ +# LLamaSharp.SemanticKernel + +LLamaSharp.SemanticKernel are connections for [SemanticKernel](https://github.com/microsoft/semantic-kernel): an SDK for intergrating various LLM interfaces into a single implementation. With this, you can add local LLaMa queries as another connection point with your existing connections. + +For reference on how to implement it, view the following examples: + +- [SemanticKernelChat](../LLama.Examples/NewVersion/SemanticKernelChat.cs) +- [SemanticKernelPrompt](../LLama.Examples/NewVersion/SemanticKernelPrompt.cs) + +## ITextCompletion +```csharp +using var model = LLamaWeights.LoadFromFile(parameters); +// LLamaSharpTextCompletion can accept ILLamaExecutor. +var ex = new StatelessExecutor(model, parameters); +var builder = new KernelBuilder(); +builder.WithAIService("local-llama", new LLamaSharpTextCompletion(ex), true); +``` + +## IChatCompletion +```csharp +using var model = LLamaWeights.LoadFromFile(parameters); +using var context = model.CreateContext(parameters); +// LLamaSharpChatCompletion requires InteractiveExecutor, as it's the best fit for the given command. 
+var ex = new InteractiveExecutor(context); +var chatGPT = new LLamaSharpChatCompletion(ex); +``` diff --git a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs new file mode 100644 index 00000000..40dbd3f8 --- /dev/null +++ b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs @@ -0,0 +1,27 @@ +using LLama; +using LLama.Abstractions; +using Microsoft.SemanticKernel.AI.TextCompletion; + +namespace LLamaSharp.SemanticKernel.TextCompletion; + +public sealed class LLamaSharpTextCompletion : ITextCompletion +{ + public ILLamaExecutor executor; + + public LLamaSharpTextCompletion(ILLamaExecutor executor) + { + this.executor = executor; + } + + public async Task> GetCompletionsAsync(string text, CompleteRequestSettings requestSettings, CancellationToken cancellationToken = default) + { + var result = executor.InferAsync(text, requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); + return await Task.FromResult(new List { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false); + } + + public async IAsyncEnumerable GetStreamingCompletionsAsync(string text, CompleteRequestSettings requestSettings, CancellationToken cancellationToken = default) + { + var result = executor.InferAsync(text, requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); + yield return new LLamaTextResult(result); + } +} diff --git a/LLama.SemanticKernel/TextCompletion/LLamaTextResult.cs b/LLama.SemanticKernel/TextCompletion/LLamaTextResult.cs new file mode 100644 index 00000000..e1643481 --- /dev/null +++ b/LLama.SemanticKernel/TextCompletion/LLamaTextResult.cs @@ -0,0 +1,37 @@ +using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Orchestration; +using System.Runtime.CompilerServices; +using System.Text; + +namespace LLamaSharp.SemanticKernel.TextCompletion; + +internal sealed class LLamaTextResult : ITextStreamingResult +{ + private readonly IAsyncEnumerable _text; + + public LLamaTextResult(IAsyncEnumerable text) + { + _text = text; + ModelResult = new(text); + } + + public ModelResult ModelResult { get; } + + public async Task GetCompletionAsync(CancellationToken cancellationToken = default) + { + var sb = new StringBuilder(); + await foreach (var token in _text) + { + sb.Append(token); + } + return await Task.FromResult(sb.ToString()).ConfigureAwait(false); + } + + public async IAsyncEnumerable GetCompletionStreamingAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (string word in _text) + { + yield return word; + } + } +} diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 308b13ad..832f3fdd 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -1,15 +1,60 @@ -using LLama; using LLama.Common; namespace LLama.Unittest { public class BasicTest + : IDisposable { + private readonly ModelParams _params; + private readonly LLamaWeights _model; + + public BasicTest() + { + _params = new ModelParams(Constants.ModelPath) + { + ContextSize = 2048 + }; + _model = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _model.Dispose(); + } + [Fact] - public void LoadModel() + public void BasicModelProperties() { - var model = new LLamaModel(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 256)); - model.Dispose(); + Assert.Equal(32000, _model.VocabCount); + Assert.Equal(2048, _model.ContextSize); + Assert.Equal(4096, _model.EmbeddingSize); + } + 
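These tests, like the reworked examples earlier in this patch, all follow the same pattern: build a ModelParams, load the weights once with LLamaWeights.LoadFromFile, then create contexts and executors from those weights. A minimal sketch of that pattern, with parameter values taken from the examples and the test model path used for illustration:

```csharp
using LLama;
using LLama.Common;

// Values mirror the examples in this patch; the path is the unit-test model constant.
var parameters = new ModelParams("Models/llama-2-7b.q4_0.gguf")
{
    ContextSize = 1024,
    Seed = 1337,
    GpuLayerCount = 5
};

// Load the weights once...
using var model = LLamaWeights.LoadFromFile(parameters);

// ...then create as many contexts as needed from the same weights
// (TalkToYourself creates two), and wrap a context in an executor.
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
```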
+ [Fact] + public void CloneContext() + { + var original = _model.CreateContext(_params); + + // Evaluate something (doesn't matter what, as long as it begins with token 1) + original.Eval(new[] { 1, 42, 321 }, 0); + + // Clone current state + var clone = original.Clone(); + + // Now evaluate something more + var reply1a = original.Eval(new[] { 4, 5, 6 }, 3); + var reply2a = original.Eval(new[] { 7, 8, 9 }, 6); + + // Assert that the context replied differently each time + Assert.NotEqual(reply1a, reply2a); + + // Give the same prompts to the cloned state + var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3); + var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6); + + // Assert that the cloned context replied in the same way as originally + Assert.Equal(reply1a, reply1b); + Assert.Equal(reply2a, reply2b); } } } \ No newline at end of file diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs new file mode 100644 index 00000000..21328b41 --- /dev/null +++ b/LLama.Unittest/Constants.cs @@ -0,0 +1,7 @@ +namespace LLama.Unittest +{ + internal static class Constants + { + public static string ModelPath = "Models/llama-2-7b.q4_0.gguf"; + } +} diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs new file mode 100644 index 00000000..6d3adb82 --- /dev/null +++ b/LLama.Unittest/GrammarParserTest.cs @@ -0,0 +1,241 @@ +using LLama.Native; +using LLama.Grammars; + +namespace LLama.Unittest +{ + /// + /// Source: + /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/tests/test-grammar-parser.cpp + /// + /// The commit hash from URL is the actual commit hash that reflects current C# code. + /// + public sealed class GrammarParserTest + { + [Fact] + public void ParseComplexGrammar() + { + GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); + string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ + expr ::= term ([-+*/] term)* + term ::= [0-9]+"; + + var state = parsedGrammar.Parse(grammarBytes, "root"); + Assert.Equal(0ul, state.StartRuleIndex); + + var expected = new List> + { + new KeyValuePair("expr", 2), + new KeyValuePair("expr_5", 5), + new KeyValuePair("expr_6", 6), + new KeyValuePair("root", 0), + new KeyValuePair("root_1", 1), + new KeyValuePair("root_4", 4), + new KeyValuePair("term", 3), + new KeyValuePair("term_7", 7), + }; + + foreach (var symbol in expected) + { + var rule = state.Rules[(int)symbol.Value]; + Assert.Equal(symbol.Key, rule.Name); + } + + var expectedRules = new List + { + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new 
LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + }; + + uint index = 0; + foreach (var rule in state.Rules) + { + // compare rule to expected rule + for (uint i = 0; i < rule.Elements.Count; i++) + { + var element = rule.Elements[(int)i]; + var expectedElement = expectedRules[(int)index]; + + // Pretty print error message before asserting + if (expectedElement.Type != element.Type || expectedElement.Value != element.Value) + { + Console.Error.WriteLine($"index: {index}"); + Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}"); + Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}"); + Console.Error.WriteLine("expected_element != actual_element"); + } + Assert.Equal(expectedElement.Type, element.Type); + Assert.Equal(expectedElement.Value, element.Value); + index++; + } + } + Assert.NotEmpty(state.Rules); + } + + [Fact] + public void ParseExtraComplexGrammar() + { + GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); + string grammarBytes = @" + root ::= (expr ""="" ws term ""\n"")+ + expr ::= term ([-+*/] term)* + term ::= ident | num | ""("" ws expr "")"" ws + ident ::= [a-z] [a-z0-9_]* ws + num ::= [0-9]+ ws + ws ::= [ \t\n]* + "; + + var state = parsedGrammar.Parse(grammarBytes, "root"); + Assert.Equal(0ul, state.StartRuleIndex); + + var expected = new List> + { + new KeyValuePair("expr", 2), + new KeyValuePair("expr_6", 6), + new KeyValuePair("expr_7", 7), + new KeyValuePair("ident", 8), + new KeyValuePair("ident_10", 10), + new KeyValuePair("num", 9), + new KeyValuePair("num_11", 11), + new KeyValuePair("root", 0), + new KeyValuePair("root_1", 1), + new KeyValuePair("root_5", 5), + new KeyValuePair("term", 4), + new KeyValuePair("ws", 3), + new KeyValuePair("ws_12", 12), + }; + + foreach (var symbol in expected) + { + var rule = state.Rules[(int)symbol.Value]; + Assert.Equal(symbol.Key, rule.Name); + } + + var expectedRules = new List + { + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new 
LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 8), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 9), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 40), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 41), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 95), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 9), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 10), + new 
LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0) + }; + + uint index = 0; + foreach (var rule in state.Rules) + { + // compare rule to expected rule + for (uint i = 0; i < rule.Elements.Count; i++) + { + var element = rule.Elements[(int)i]; + var expectedElement = expectedRules[(int)index]; + + // Pretty print error message before asserting + if (expectedElement.Type != element.Type || expectedElement.Value != element.Value) + { + Console.Error.WriteLine($"index: {index}"); + Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}"); + Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}"); + Console.Error.WriteLine("expected_element != actual_element"); + } + Assert.Equal(expectedElement.Type, element.Type); + Assert.Equal(expectedElement.Value, element.Value); + index++; + } + } + Assert.NotEmpty(state.Rules); + } + } +} diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs new file mode 100644 index 00000000..152ede93 --- /dev/null +++ b/LLama.Unittest/GrammarTest.cs @@ -0,0 +1,74 @@ +using LLama.Common; +using LLama.Grammars; +using LLama.Native; + +namespace LLama.Unittest +{ + public sealed class GrammarTest + : IDisposable + { + private readonly ModelParams _params; + private readonly LLamaWeights _model; + + public GrammarTest() + { + _params = new ModelParams(Constants.ModelPath) + { + ContextSize = 2048, + }; + _model = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _model.Dispose(); + } + + [Fact] + public void CreateBasicGrammar() + { + var rules = new List + { + new GrammarRule("alpha", new[] + { + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + }), + }; + + using var handle = SafeLLamaGrammarHandle.Create(rules, 0); + } + + [Fact] + public void SampleWithTrivialGrammar() + { + // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so + // we can be confident it's not what the LLM would say if not constrained by the grammar! 
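// A minimal sketch for comparison (not part of this test): the same single-word constraint
// could also be written in GBNF and converted to a native handle with the parser introduced
// elsewhere in this diff, assuming the Grammar.Parse / CreateInstance API shown there:
//
//   var gbnf = Grammar.Parse("root ::= \"cat\"", "root");
//   using var handle2 = gbnf.CreateInstance();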
+ var rules = new List + { + new GrammarRule("feline", new [] + { + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + }), + }; + + using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); + + var executor = new StatelessExecutor(_model, _params); + var inferenceParams = new InferenceParams + { + MaxTokens = 3, + AntiPrompts = new [] { ".", "Input:", "\n" }, + Grammar = grammar, + }; + + var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); + + Assert.Equal("cat", result[0]); + } + } +} diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 81e71a88..ea0e100a 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -11,20 +11,20 @@ - - - + + + runtime; build; native; contentfiles; analyzers; buildtransitive all - + runtime; build; native; contentfiles; analyzers; buildtransitive all - + @@ -37,7 +37,7 @@ - + PreserveNewest diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs new file mode 100644 index 00000000..e9c84eac --- /dev/null +++ b/LLama.Unittest/LLamaContextTests.cs @@ -0,0 +1,35 @@ +using LLama.Common; + +namespace LLama.Unittest +{ + public class LLamaContextTests + : IDisposable + { + private readonly LLamaWeights _weights; + private readonly LLamaContext _context; + + public LLamaContextTests() + { + var @params = new ModelParams(Constants.ModelPath) + { + ContextSize = 768, + }; + _weights = LLamaWeights.LoadFromFile(@params); + _context = _weights.CreateContext(@params); + } + + public void Dispose() + { + _weights.Dispose(); + _context.Dispose(); + } + + [Fact] + public void CheckProperties() + { + Assert.Equal(768, _context.ContextSize); + Assert.Equal(4096, _context.EmbeddingSize); + Assert.Equal(32000, _context.VocabCount); + } + } +} diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs new file mode 100644 index 00000000..f94c90ba --- /dev/null +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -0,0 +1,71 @@ +using LLama.Common; + +namespace LLama.Unittest; + +public class LLamaEmbedderTests + : IDisposable +{ + private readonly LLamaEmbedder _embedder = new(new ModelParams(Constants.ModelPath)); + + public void Dispose() + { + _embedder.Dispose(); + } + + private static float Magnitude(float[] a) + { + return MathF.Sqrt(a.Zip(a, (x, y) => x * y).Sum()); + } + + private static void Normalize(float[] a) + { + var mag = Magnitude(a); + for (var i = 0; i < a.Length; i++) + a[i] /= mag; + } + + private static float Dot(float[] a, float[] b) + { + Assert.Equal(a.Length, b.Length); + return a.Zip(b, (x, y) => x * y).Sum(); + } + + private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f) + { + for (int i = 0; i < expected.Length; i++) + Assert.Equal(expected[i], actual[i], epsilon); + } + + // todo: enable this one llama2 7B gguf is available + //[Fact] + //public void EmbedBasic() + //{ + // var cat = _embedder.GetEmbeddings("cat"); + + // Assert.NotNull(cat); + // Assert.NotEmpty(cat); + + // // Expected value generate with llama.cpp embedding.exe + // var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f }; + // AssertApproxStartsWith(expected, cat); + //} + + [Fact] + public void EmbedCompare() + { + var cat 
= _embedder.GetEmbeddings("cat"); + var kitten = _embedder.GetEmbeddings("kitten"); + var spoon = _embedder.GetEmbeddings("spoon"); + + Normalize(cat); + Normalize(kitten); + Normalize(spoon); + + var close = Dot(cat, kitten); + var far = Dot(cat, spoon); + + // This comparison seems backwards, but remember that with a + // dot product 1.0 means **identical** and 0.0 means **completely opposite**! + Assert.True(close > far); + } +} \ No newline at end of file diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs new file mode 100644 index 00000000..413bda83 --- /dev/null +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -0,0 +1,70 @@ +using System.Text; +using LLama.Common; +using Newtonsoft.Json; + +namespace LLama.Unittest +{ + public class ModelsParamsTests + { + [Fact] + public void SerializeRoundTripSystemTextJson() + { + var expected = new ModelParams("abc/123") + { + BatchSize = 17, + ContextSize = 42, + LoraAdapter = "adapter", + Seed = 42, + GpuLayerCount = 111 + }; + + var json = System.Text.Json.JsonSerializer.Serialize(expected); + var actual = System.Text.Json.JsonSerializer.Deserialize(json); + + Assert.Equal(expected, actual); + } + + [Fact] + public void SerializeRoundTripNewtonsoft() + { + var expected = new ModelParams("abc/123") + { + BatchSize = 17, + ContextSize = 42, + LoraAdapter = "adapter", + Seed = 42, + GpuLayerCount = 111 + }; + + var settings = new Newtonsoft.Json.JsonSerializerSettings(); + settings.Converters.Add(new NewtsonsoftEncodingConverter()); + + var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings); + var actual = Newtonsoft.Json.JsonConvert.DeserializeObject(json, settings); + + Assert.Equal(expected, actual); + } + + + + public class NewtsonsoftEncodingConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return typeof(Encoding).IsAssignableFrom(objectType); + } + + public override void WriteJson(JsonWriter writer, object value, JsonSerializer serializer) + { + writer.WriteValue(((Encoding)value).WebName); + } + + public override object ReadJson(JsonReader reader, Type objectType, object existingValue, JsonSerializer serializer) + { + return Encoding.GetEncoding((string)reader.Value); + } + } + + + } +} diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs new file mode 100644 index 00000000..1748e02d --- /dev/null +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -0,0 +1,70 @@ +using LLama.Common; +using Xunit.Abstractions; + +namespace LLama.Unittest +{ + public class StatelessExecutorTest + : IDisposable + { + private readonly ITestOutputHelper _testOutputHelper; + private readonly LLamaWeights _weights; + private readonly ModelParams _params; + + public StatelessExecutorTest(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + _params = new ModelParams(Constants.ModelPath) + { + ContextSize = 60, + Seed = 1754 + }; + _weights = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _weights.Dispose(); + } + + [Fact] + public void Stateless() + { + var executor = new StatelessExecutor(_weights, _params); + + const string question = "Question. what is a cat?\nAnswer: "; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." 
} }; + + var result1 = string.Join("", executor.Infer(question, @params)); + var result2 = string.Join("", executor.Infer(question, @params)); + + _testOutputHelper.WriteLine(result1); + + // Check that it produced the exact same result both times + Assert.Equal(result1, result2); + } + + [Fact] + public void OutOfContext() + { + var executor = new StatelessExecutor(_weights, _params); + + const string question = " Question. why is a cat the best pet?\nAnswer: "; + + // The context size is set to 60. Generate more than that, forcing it to generate a coherent response + // with a modified context + var @params = new InferenceParams() + { + MaxTokens = 100, + TokensKeep = question.Length, + }; + + var result1 = string.Join("", executor.Infer(question, @params)); + var result2 = string.Join("", executor.Infer(question, @params)); + + _testOutputHelper.WriteLine(result1); + + // Check that it produced the exact same result both times + Assert.Equal(result1, result2); + } + } +} \ No newline at end of file diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index e8b89dee..f06757e3 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -1,8 +1,10 @@ -using LLama.Abstractions; +using System.Text; +using LLama.Abstractions; namespace LLama.Web.Common { - public class ModelOptions : IModelParams + public class ModelOptions + : IModelParams { public string Name { get; set; } @@ -86,16 +88,6 @@ namespace LLama.Web.Common /// public float[] TensorSplits { get; set; } - /// - /// Grouped-Query Attention - /// - public int GroupedQueryAttention { get; set; } = 1; - - /// - /// RMS Norm Epsilon - /// - public float RmsNormEpsilon { get; set; } = 5e-6f; - /// /// RoPE base frequency /// @@ -111,5 +103,9 @@ namespace LLama.Web.Common /// public bool MulMatQ { get; set; } - } + /// + /// The encoding to use for models + /// + public Encoding Encoding { get; set; } = Encoding.UTF8; + } } diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs index 7677f04a..f78aa861 100644 --- a/LLama.Web/Common/ParameterOptions.cs +++ b/LLama.Web/Common/ParameterOptions.cs @@ -1,5 +1,6 @@ using LLama.Common; using LLama.Abstractions; +using LLama.Native; namespace LLama.Web.Common { @@ -95,5 +96,10 @@ namespace LLama.Web.Common /// consider newlines as a repeatable token (penalize_nl) /// public bool PenalizeNL { get; set; } = true; - } + + /// + /// A grammar to constrain possible tokens + /// + public SafeLLamaGrammarHandle Grammar { get; set; } = null; + } } diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index d6d42813..c53676f2 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -60,7 +60,8 @@ namespace LLama.Web.Models { _inferenceOptions = null; _outputTransform = null; - _executor.Model?.Dispose(); + + _executor?.Context.Dispose(); _executor = null; } } diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs index 6c266f14..7dfcde39 100644 --- a/LLama.Web/Services/ConnectionSessionService.cs +++ b/LLama.Web/Services/ConnectionSessionService.cs @@ -51,7 +51,7 @@ namespace LLama.Web.Services return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); // Create model - var llamaModel = new LLamaModel(modelOption); + var llamaModel = new LLamaContext(modelOption); // Create executor ILLamaExecutor executor = executorType switch diff --git 
a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index ab89b517..a9ac3a44 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -8,7 +8,7 @@ namespace LLama.WebAPI.Services; public class StatefulChatService : IDisposable { private readonly ChatSession _session; - private readonly LLamaModel _model; + private readonly LLamaContext _context; private bool _continue = false; private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" @@ -16,13 +16,16 @@ public class StatefulChatService : IDisposable public StatefulChatService(IConfiguration configuration) { - _model = new LLamaModel(new Common.ModelParams(configuration["ModelPath"], contextSize: 512)); - _session = new ChatSession(new InteractiveExecutor(_model)); + _context = new LLamaContext(new Common.ModelParams(configuration["ModelPath"]) + { + ContextSize = 512 + }); + _session = new ChatSession(new InteractiveExecutor(_context)); } public void Dispose() { - _model?.Dispose(); + _context?.Dispose(); } public string Send(SendMessageInput input) diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs index c1356646..b924f4d8 100644 --- a/LLama.WebAPI/Services/StatelessChatService.cs +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -7,14 +7,17 @@ namespace LLama.WebAPI.Services { public class StatelessChatService { - private readonly LLamaModel _model; + private readonly LLamaContext _context; private readonly ChatSession _session; public StatelessChatService(IConfiguration configuration) { - _model = new LLamaModel(new ModelParams(configuration["ModelPath"], contextSize: 512)); + _context = new LLamaContext(new ModelParams(configuration["ModelPath"]) + { + ContextSize = 512, + }); // TODO: replace with a stateless executor - _session = new ChatSession(new InteractiveExecutor(_model)) + _session = new ChatSession(new InteractiveExecutor(_context)) .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8)) .WithHistoryTransform(new HistoryTransform()); } diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index 75a78bc5..c9217ae0 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -1,7 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Text; namespace LLama.Abstractions { diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index 73cbbfd2..e576366f 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using LLama.Common; +using LLama.Native; namespace LLama.Abstractions { @@ -113,5 +114,10 @@ namespace LLama.Abstractions /// consider newlines as a repeatable token (penalize_nl) /// public bool PenalizeNL { get; set; } + + /// + /// Grammar to constrain possible tokens + /// + SafeLLamaGrammarHandle? 
Grammar { get; set; } } } \ No newline at end of file diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 6a750895..a7af0243 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -1,7 +1,4 @@ -using LLama.Common; -using System; -using System.Collections.Generic; -using System.Text; +using System.Collections.Generic; using System.Threading; namespace LLama.Abstractions @@ -12,9 +9,9 @@ namespace LLama.Abstractions public interface ILLamaExecutor { /// - /// The loaded model for this executor. + /// The loaded context for this executor. /// - public LLamaModel Model { get; } + public LLamaContext Context { get; } /// /// Infers a response from the model. diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index fdc91152..700d98e2 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -1,7 +1,10 @@ -using System; +using System.Text; namespace LLama.Abstractions { + /// + /// The parameters for initializing a LLama model. + /// public interface IModelParams { /// @@ -95,16 +98,6 @@ namespace LLama.Abstractions /// float[]? TensorSplits { get; set; } - /// - /// Grouped-Query Attention - /// - int GroupedQueryAttention { get; set; } - - /// - /// RMS Norm Epsilon - /// - float RmsNormEpsilon { get; set; } - /// /// RoPE base frequency /// @@ -119,5 +112,10 @@ namespace LLama.Abstractions /// Use experimental mul_mat_q kernels /// bool MulMatQ { get; set; } + + /// + /// The encoding to use for models + /// + Encoding Encoding { get; set; } } } \ No newline at end of file diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs index af083564..e96febcf 100644 --- a/LLama/Abstractions/ITextStreamTransform.cs +++ b/LLama/Abstractions/ITextStreamTransform.cs @@ -1,6 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Text; +using System.Collections.Generic; namespace LLama.Abstractions { @@ -15,6 +13,7 @@ namespace LLama.Abstractions /// /// IEnumerable Transform(IEnumerable tokens); + /// /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously. /// diff --git a/LLama/Abstractions/ITextTransform.cs b/LLama/Abstractions/ITextTransform.cs index c165e807..ac196644 100644 --- a/LLama/Abstractions/ITextTransform.cs +++ b/LLama/Abstractions/ITextTransform.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace LLama.Abstractions +namespace LLama.Abstractions { /// /// An interface for text transformations. 
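Taken together, the interface changes above move callers off the old LLamaModel-centric API: executors now expose a Context, inference parameters can carry a SafeLLamaGrammarHandle, and model parameters carry an Encoding. A minimal usage sketch based on the tests earlier in this diff (the model path, prompt and GBNF text are illustrative only, not taken from the PR):

using System;
using LLama;
using LLama.Common;
using LLama.Grammars;

// Load weights once, then create executors/contexts from them (illustrative path)
var @params = new ModelParams("path/to/model.gguf") { ContextSize = 1024 };
using var weights = LLamaWeights.LoadFromFile(@params);

// Parse a GBNF grammar and use it to constrain sampling (API added in this PR)
var grammar = Grammar.Parse("root ::= \"yes\" | \"no\"", "root");
using var handle = grammar.CreateInstance();

var executor = new StatelessExecutor(weights, @params);
var inference = new InferenceParams { MaxTokens = 8, Grammar = handle };

foreach (var text in executor.Infer("Is a cat an animal? Answer: ", inference))
    Console.Write(text);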
diff --git a/LLama/AssemblyAttributes.cs b/LLama/AssemblyAttributes.cs new file mode 100644 index 00000000..dab58d12 --- /dev/null +++ b/LLama/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("LLama.Unittest")] \ No newline at end of file diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 4a4544b0..5ed6a459 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -5,6 +5,7 @@ using System.IO; using System.Runtime.CompilerServices; using System.Text; using System.Threading; +using System.Threading.Tasks; namespace LLama { @@ -13,10 +14,12 @@ namespace LLama /// public class ChatSession { - private ILLamaExecutor _executor; - private ChatHistory _history; - private static readonly string _executorStateFilename = "ExecutorState.json"; - private static readonly string _modelStateFilename = "ModelState.st"; + private readonly ILLamaExecutor _executor; + private readonly ChatHistory _history; + + private const string _executorStateFilename = "ExecutorState.json"; + private const string _modelStateFilename = "ModelState.st"; + /// /// The executor for this session. /// @@ -91,7 +94,7 @@ namespace LLama { Directory.CreateDirectory(path); } - _executor.Model.SaveState(Path.Combine(path, _modelStateFilename)); + _executor.Context.SaveState(Path.Combine(path, _modelStateFilename)); if(Executor is StatelessExecutor) { @@ -116,7 +119,7 @@ namespace LLama { throw new FileNotFoundException($"Directory {path} does not exist."); } - _executor.Model.LoadState(Path.Combine(path, _modelStateFilename)); + _executor.Context.LoadState(Path.Combine(path, _modelStateFilename)); if (Executor is StatelessExecutor) { @@ -227,7 +230,7 @@ namespace LLama private async IAsyncEnumerable ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); - await foreach (var item in OutputTransform.TransformAsync(results)) + await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) { yield return item; } diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index 28ae53fc..2c331e5a 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -15,9 +15,20 @@ namespace LLama.Common private readonly int _maxSize; private readonly List _storage; + /// + /// Number of items in this queue + /// public int Count => _storage.Count; + + /// + /// Maximum number of items allowed in this queue + /// public int Capacity => _maxSize; + /// + /// Create a new queue + /// + /// the maximum number of items to store in this queue public FixedSizeQueue(int size) { _maxSize = size; diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index 001a8f8e..64d2652b 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -1,6 +1,7 @@ using LLama.Abstractions; using System; using System.Collections.Generic; +using LLama.Native; namespace LLama.Common { @@ -96,6 +97,11 @@ namespace LLama.Common /// consider newlines as a repeatable token (penalize_nl) /// public bool PenalizeNL { get; set; } = true; + + /// + /// A grammar to constrain the possible tokens + /// + public SafeLLamaGrammarHandle? 
Grammar { get; set; } } /// diff --git a/LLama/Common/Logger.cs b/LLama/Common/Logger.cs index 9bcd927e..edff64f9 100644 --- a/LLama/Common/Logger.cs +++ b/LLama/Common/Logger.cs @@ -1,21 +1,44 @@ -using System; +using LLama.Native; +using System; using System.Diagnostics; using System.IO; using static LLama.Common.ILLamaLogger; namespace LLama.Common; +/// +/// receives log messages from LLamaSharp +/// public interface ILLamaLogger { + /// + /// Severity level of a log message + /// public enum LogLevel { - Info, - Debug, - Warning, - Error + /// + /// Logs that are used for interactive investigation during development. + /// + Debug = 1, + + /// + /// Logs that highlight when the current flow of execution is stopped due to a failure. + /// + Error = 2, + + /// + /// Logs that highlight an abnormal or unexpected event in the application flow, but do not otherwise cause the application execution to stop. + /// + Warning = 3, + + /// + /// Logs that track the general flow of the application. + /// + Info = 4 } + /// - /// Write the log in cosutomized way + /// Write the log in customized way /// /// The source of the log. It may be a method name or class name. /// The message. @@ -24,19 +47,23 @@ public interface ILLamaLogger } /// -/// The default logger of LLamaSharp. On default it write to console. User methods of `LLamaLogger.Default` to change the behavior. -/// It's more recommended to inherit `ILLamaLogger` to cosutomize the behavior. +/// The default logger of LLamaSharp. On default it write to console. Use methods of `LLamaLogger.Default` to change the behavior. +/// It's recommended to inherit `ILLamaLogger` to customize the behavior. /// -public sealed class LLamaDefaultLogger : ILLamaLogger +public sealed class LLamaDefaultLogger + : ILLamaLogger { private static readonly Lazy _instance = new Lazy(() => new LLamaDefaultLogger()); private bool _toConsole = true; - private bool _toFile = false; + private bool _toFile; - private FileStream? _fileStream = null; - private StreamWriter _fileWriter = null; + private FileStream? _fileStream; + private StreamWriter? _fileWriter; + /// + /// Get the default logger instance + /// public static LLamaDefaultLogger Default => _instance.Value; private LLamaDefaultLogger() @@ -44,18 +71,42 @@ public sealed class LLamaDefaultLogger : ILLamaLogger } + /// + /// Enable logging output from llama.cpp + /// + /// + public LLamaDefaultLogger EnableNative() + { + EnableNativeLogCallback(); + return this; + } + + /// + /// Enable writing log messages to console + /// + /// public LLamaDefaultLogger EnableConsole() { _toConsole = true; return this; } + /// + /// Disable writing messages to console + /// + /// public LLamaDefaultLogger DisableConsole() { _toConsole = false; return this; } + /// + /// Enable writing log messages to file + /// + /// + /// + /// public LLamaDefaultLogger EnableFile(string filename, FileMode mode = FileMode.Append) { _fileStream = new FileStream(filename, mode, FileAccess.Write); @@ -64,7 +115,22 @@ public sealed class LLamaDefaultLogger : ILLamaLogger return this; } + /// + /// Disable writing log messages to file + /// + /// unused! 
+ /// + [Obsolete("Use DisableFile method without 'filename' parameter")] public LLamaDefaultLogger DisableFile(string filename) + { + return DisableFile(); + } + + /// + /// Disable writing log messages to file + /// + /// + public LLamaDefaultLogger DisableFile() { if (_fileWriter is not null) { @@ -80,6 +146,12 @@ public sealed class LLamaDefaultLogger : ILLamaLogger return this; } + /// + /// Log a message + /// + /// The source of this message (e.g. class name) + /// The message to log + /// Severity level of this message public void Log(string source, string message, LogLevel level) { if (level == LogLevel.Info) @@ -100,6 +172,10 @@ public sealed class LLamaDefaultLogger : ILLamaLogger } } + /// + /// Write a log message with "Info" severity + /// + /// public void Info(string message) { message = MessageFormat("info", message); @@ -117,6 +193,10 @@ public sealed class LLamaDefaultLogger : ILLamaLogger } } + /// + /// Write a log message with "Warn" severity + /// + /// public void Warn(string message) { message = MessageFormat("warn", message); @@ -134,6 +214,10 @@ public sealed class LLamaDefaultLogger : ILLamaLogger } } + /// + /// Write a log message with "Error" severity + /// + /// public void Error(string message) { message = MessageFormat("error", message); @@ -151,10 +235,36 @@ public sealed class LLamaDefaultLogger : ILLamaLogger } } - private string MessageFormat(string level, string message) + private static string MessageFormat(string level, string message) { - DateTime now = DateTime.Now; - string formattedDate = now.ToString("yyyy.MM.dd HH:mm:ss"); - return $"[{formattedDate}][{level}]: {message}"; + var now = DateTime.Now; + return $"[{now:yyyy.MM.dd HH:mm:ss}][{level}]: {message}"; } + + /// + /// Register native logging callback + /// + private void EnableNativeLogCallback() + { + // TODO: Move to a more appropriate place once we have a intitialize method + NativeApi.llama_log_set(NativeLogCallback); + } + + /// + /// Callback for native logging function + /// + /// The log level + /// The log message + private void NativeLogCallback(LogLevel level, string message) + { + if (string.IsNullOrEmpty(message)) + return; + + // Note that text includes the new line character at the end for most events. + // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it + // if it exists. + // It might not exist for progress report where '.' is output repeatedly. + Log(default!, message.TrimEnd('\n'), level); + } + } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 5cb81078..e0b0c264 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -1,12 +1,15 @@ using LLama.Abstractions; using System; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; namespace LLama.Common { /// /// The parameters for initializing a LLama model. /// - public class ModelParams + public record ModelParams : IModelParams { /// @@ -86,16 +89,6 @@ namespace LLama.Common /// public float[]? TensorSplits { get; set; } - /// - /// Grouped-Query Attention - /// - public int GroupedQueryAttention { get; set; } = 1; - - /// - /// RMS Norm Epsilon - /// - public float RmsNormEpsilon { get; set; } = 5e-6f; - /// /// RoPE base frequency /// @@ -111,34 +104,57 @@ namespace LLama.Common /// public bool MulMatQ { get; set; } - /// - /// - /// - /// The model path. 
- /// Model context size (n_ctx) - /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) - /// Seed for the random number generator (seed) - /// Whether to use f16 instead of f32 for memory kv (memory_f16) - /// Whether to use mmap for faster loads (use_mmap) - /// Whether to use mlock to keep model in memory (use_mlock) - /// Thether to compute perplexity over the prompt (perplexity) - /// Lora adapter path (lora_adapter) - /// Base model path for the lora adapter (lora_base) - /// Number of threads (-1 = autodetect) (n_threads) - /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch) - /// Whether to convert eos to newline during the inference. - /// Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore. - /// Grouped-Query Attention - /// RMS Norm Epsilon - /// RoPE base frequency. - /// RoPE frequency scaling factor - /// Use experimental mul_mat_q kernels - public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, - int seed = 1337, bool useFp16Memory = true, - bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, - string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, - bool convertEosToNewLine = false, bool embeddingMode = false, - int gqa = 1, float rmsNormEps = 5e-6f, float rope_freq_base = 10000.0f, float rope_freq_scale = 1f, bool muMatQ = false) + /// + /// The encoding to use to convert text for the model + /// + [JsonConverter(typeof(EncodingConverter))] + public Encoding Encoding { get; set; } = Encoding.UTF8; + + /// + /// + /// + /// The model path. + [JsonConstructor] + public ModelParams(string modelPath) + { + ModelPath = modelPath; + } + + private ModelParams() + { + // This constructor (default parameterless constructor) is used by Newtonsoft to deserialize! + ModelPath = ""; + } + + /// + /// + /// + /// The model path. + /// Model context size (n_ctx) + /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) + /// Seed for the random number generator (seed) + /// Whether to use f16 instead of f32 for memory kv (memory_f16) + /// Whether to use mmap for faster loads (use_mmap) + /// Whether to use mlock to keep model in memory (use_mlock) + /// Thether to compute perplexity over the prompt (perplexity) + /// Lora adapter path (lora_adapter) + /// Base model path for the lora adapter (lora_base) + /// Number of threads (-1 = autodetect) (n_threads) + /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// Whether to convert eos to newline during the inference. + /// Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore. + /// RoPE base frequency. 
+ /// RoPE frequency scaling factor + /// Use experimental mul_mat_q kernels + /// The encoding to use to convert text for the model + [Obsolete("Use object initializer to set all optional parameters")] + public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, + int seed = 1337, bool useFp16Memory = true, + bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, + string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, + bool convertEosToNewLine = false, bool embeddingMode = false, + float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false, + string encoding = "UTF-8") { ContextSize = contextSize; GpuLayerCount = gpuLayerCount; @@ -154,11 +170,27 @@ namespace LLama.Common BatchSize = batchSize; ConvertEosToNewLine = convertEosToNewLine; EmbeddingMode = embeddingMode; - GroupedQueryAttention = gqa; - RmsNormEpsilon = rmsNormEps; - RopeFrequencyBase = rope_freq_base; - RopeFrequencyScale = rope_freq_scale; - MulMatQ = muMatQ; - } + RopeFrequencyBase = ropeFrequencyBase; + RopeFrequencyScale = ropeFrequencyScale; + MulMatQ = mulMatQ; + Encoding = Encoding.GetEncoding(encoding); + } + } + + internal class EncodingConverter + : JsonConverter + { + public override Encoding? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var name = reader.GetString(); + if (name == null) + return null; + return Encoding.GetEncoding(name); + } + + public override void Write(Utf8JsonWriter writer, Encoding value, JsonSerializerOptions options) + { + writer.WriteStringValue(value.WebName); + } } } diff --git a/LLama/Exceptions/GrammarFormatExceptions.cs b/LLama/Exceptions/GrammarFormatExceptions.cs new file mode 100644 index 00000000..62b75224 --- /dev/null +++ b/LLama/Exceptions/GrammarFormatExceptions.cs @@ -0,0 +1,125 @@ +using System; + +namespace LLama.Exceptions; + +/// +/// Base class for all grammar exceptions +/// +public abstract class GrammarFormatException + : Exception +{ + internal GrammarFormatException(string message) + : base(message) + { + } +} + + +/// +/// An incorrect number of characters were encountered while parsing a hex literal +/// +public class GrammarUnexpectedHexCharsCount + : GrammarFormatException +{ + internal GrammarUnexpectedHexCharsCount(int size, string source) + : base($"Expecting {size} hex chars at {source}") + { + } +} + +/// +/// Failed to parse a "name" element when one was expected +/// +public class GrammarExpectedName + : GrammarFormatException +{ + internal GrammarExpectedName(string source) + : base($"Expecting name at {source}") + { + } +} + +/// +/// An unexpected character was encountered after an escape sequence +/// +public class GrammarUnknownEscapeCharacter + : GrammarFormatException +{ + internal GrammarUnknownEscapeCharacter(string source) + : base($"Unknown escape at {source}") + { + } +} + +/// +/// End-of-file was encountered while parsing +/// +public class GrammarUnexpectedEndOfInput + : GrammarFormatException +{ + internal GrammarUnexpectedEndOfInput() + : base($"Unexpected end of input") + { + } +} + +/// +/// A specified string was expected when parsing +/// +public class GrammarExpectedNext + : GrammarFormatException +{ + internal GrammarExpectedNext(string expected, string source) + : base($"Expected '{expected}' at {source}") + { + } +} + +/// +/// A specified character was expected to preceded another when parsing +/// +public class GrammarExpectedPrevious + : GrammarFormatException +{ + internal 
GrammarExpectedPrevious(string expected, string source) + : base($"Expecting preceding item to be '{expected}' at {source}") + { + } +} + + +/// +/// A CHAR_ALT was created without a preceding CHAR element +/// +public class GrammarUnexpectedCharAltElement + : GrammarFormatException +{ + internal GrammarUnexpectedCharAltElement(string ruleId, int index) + : base($"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{index}") + { + } +} + +/// +/// A CHAR_RNG was created without a preceding CHAR element +/// +public class GrammarUnexpectedCharRngElement + : GrammarFormatException +{ + internal GrammarUnexpectedCharRngElement(string ruleId, int index) + : base($"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{index}") + { + } +} + +/// +/// An END was encountered before the last element +/// +public class GrammarUnexpectedEndElement + : GrammarFormatException +{ + internal GrammarUnexpectedEndElement(string ruleId, int index) + : base($"Unexpected LLamaGrammarElementType.END: {ruleId},{index}") + { + } +} \ No newline at end of file diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs index 789f035a..6b839ff0 100644 --- a/LLama/Exceptions/RuntimeError.cs +++ b/LLama/Exceptions/RuntimeError.cs @@ -2,14 +2,16 @@ namespace LLama.Exceptions { - public class RuntimeError: Exception + public class RuntimeError + : Exception { public RuntimeError() { } - public RuntimeError(string message): base(message) + public RuntimeError(string message) + : base(message) { } diff --git a/LLama/Extensions/DictionaryExtensions.cs b/LLama/Extensions/DictionaryExtensions.cs new file mode 100644 index 00000000..e5a27d6d --- /dev/null +++ b/LLama/Extensions/DictionaryExtensions.cs @@ -0,0 +1,14 @@ +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class DictionaryExtensions + { +#if NETSTANDARD2_0 + public static TValue GetValueOrDefault(this IReadOnlyDictionary dictionary, TKey key, TValue defaultValue) + { + return dictionary.TryGetValue(key, out var value) ? 
value : defaultValue; + } +#endif + } +} diff --git a/LLama/Extensions/IEnumerableExtensions.cs b/LLama/Extensions/IEnumerableExtensions.cs new file mode 100644 index 00000000..ebc234be --- /dev/null +++ b/LLama/Extensions/IEnumerableExtensions.cs @@ -0,0 +1,21 @@ +using System.Collections.Generic; +using System.Linq; + +namespace LLama.Extensions +{ + internal static class IEnumerableExtensions + { +#if NETSTANDARD2_0 + public static IEnumerable TakeLast(this IEnumerable source, int count) + { + var list = source.ToList(); + + if (count >= list.Count) + return list; + + list.RemoveRange(0, list.Count - count); + return list; + } +#endif + } +} diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index 93b0f86e..c4cb1c62 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -39,8 +39,6 @@ namespace LLama.Extensions result.logits_all = @params.Perplexity; result.embedding = @params.EmbeddingMode; result.low_vram = @params.LowVram; - result.n_gqa = @params.GroupedQueryAttention; - result.rms_norm_eps = @params.RmsNormEpsilon; result.rope_freq_base = @params.RopeFrequencyBase; result.rope_freq_scale = @params.RopeFrequencyScale; result.mul_mat_q = @params.MulMatQ; diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs new file mode 100644 index 00000000..51b365be --- /dev/null +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class IReadOnlyListExtensions + { + public static int? IndexOf(this IReadOnlyList list, T item) + where T : IEquatable + { + for (var i = 0; i < list.Count; i++) + { + if (list[i].Equals(item)) + return i; + } + + return null; + } + } +} diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs new file mode 100644 index 00000000..c78d311c --- /dev/null +++ b/LLama/Extensions/ListExtensions.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class ListExtensions + { + public static void AddRangeSpan(this List list, ReadOnlySpan span) + { + for (var i = 0; i < span.Length; i++) + list.Add(span[i]); + } + } +} diff --git a/LLama/GlobalSuppressions.cs b/LLama/GlobalSuppressions.cs new file mode 100644 index 00000000..2053bc25 --- /dev/null +++ b/LLama/GlobalSuppressions.cs @@ -0,0 +1,10 @@ +// This file is used by Code Analysis to maintain SuppressMessage +// attributes that are applied to this project. +// Project-level suppressions either have no target or are given +// a specific target and scoped to a namespace, type, member, etc. 
+ +using System.Diagnostics.CodeAnalysis; + +[assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] + +[assembly: SuppressMessage("Style", "IDE0070:Use 'System.HashCode'", Justification = "Not compatible with netstandard2.0")] diff --git a/LLama/Grammars/GBNFGrammarParser.cs b/LLama/Grammars/GBNFGrammarParser.cs new file mode 100644 index 00000000..cd5969e4 --- /dev/null +++ b/LLama/Grammars/GBNFGrammarParser.cs @@ -0,0 +1,406 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using LLama.Exceptions; +using LLama.Native; + +namespace LLama.Grammars +{ + /// + /// Source: + /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp + /// + /// The commit hash from URL is the actual commit hash that reflects current C# code. + /// + internal sealed class GBNFGrammarParser + { + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + private uint DecodeUTF8(ref ReadOnlySpan src) + { + int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + + byte firstByte = src[0]; + byte highbits = (byte)(firstByte >> 4); + int len = lookup[highbits]; + byte mask = (byte)((1 << (8 - len)) - 1); + uint value = (uint)(firstByte & mask); + + int end = len; + int pos = 1; + + for (; pos < end && pos < src.Length; pos++) + { + value = (uint)((value << 6) + (src[pos] & 0x3F)); + } + + src = src.Slice(pos); + + return value; + } + + private uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) + { + uint nextId = (uint)state.SymbolIds.Count; + string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray()); + + if (state.SymbolIds.TryGetValue(key, out uint existingId)) + { + return existingId; + } + else + { + state.SymbolIds[key] = nextId; + return nextId; + } + } + + private uint GenerateSymbolId(ParseState state, string baseName) + { + uint nextId = (uint)state.SymbolIds.Count; + string key = $"{baseName}_{nextId}"; + state.SymbolIds[key] = nextId; + return nextId; + } + + private void AddRule(ParseState state, uint ruleId, List rule) + { + while (state.Rules.Count <= ruleId) + { + state.Rules.Add(new List()); + } + + state.Rules[(int)ruleId] = rule; + } + + private bool IsWordChar(byte c) + { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + private uint ParseHex(ref ReadOnlySpan src, int size) + { + int pos = 0; + int end = size; + uint value = 0; + + for (; pos < end && pos < src.Length; pos++) + { + value <<= 4; + byte c = src[pos]; + if ('a' <= c && c <= 'f') + { + value += (uint)(c - 'a' + 10); + } + else if ('A' <= c && c <= 'F') + { + value += (uint)(c - 'A' + 10); + } + else if ('0' <= c && c <= '9') + { + value += (uint)(c - '0'); + } + else + { + break; + } + } + + if (pos != end) + { + throw new GrammarUnexpectedHexCharsCount(size, Encoding.UTF8.GetString(src.ToArray())); + } + src = src.Slice(pos); + return value; + } + + private ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) + { + int pos = 0; + while (pos < src.Length && + (src[pos] == ' ' || src[pos] == '\t' || src[pos] == '#' || + (newlineOk && (src[pos] == '\r' || src[pos] == '\n')))) + { + if (src[pos] == '#') + { + while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n') + { + pos++; + } + } + else + { + pos++; + } + } + return src.Slice(pos); + } + + private ReadOnlySpan ParseName(ReadOnlySpan src) + 
{ + int pos = 0; + while (pos < src.Length && IsWordChar(src[pos])) + { + pos++; + } + if (pos == 0) + { + throw new GrammarExpectedName(Encoding.UTF8.GetString(src.ToArray())); + } + return src.Slice(pos); + } + + private uint ParseChar(ref ReadOnlySpan src) + { + if (src[0] == '\\') + { + var chr = src[1]; + src = src.Slice(2); + switch (chr) + { + case (byte)'x': + return ParseHex(ref src, 2); + case (byte)'u': + return ParseHex(ref src, 4); + case (byte)'U': + return ParseHex(ref src, 8); + case (byte)'t': + return '\t'; + case (byte)'r': + return '\r'; + case (byte)'n': + return '\n'; + case (byte)'\\': + case (byte)'"': + case (byte)'[': + case (byte)']': + return chr; + default: + throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())); + } + } + else if (!src.IsEmpty) + { + return DecodeUTF8(ref src); + } + + throw new GrammarUnexpectedEndOfInput(); + } + + private ReadOnlySpan ParseSequence( + ParseState state, + ReadOnlySpan pos, + string ruleName, + List outElements, + bool isNested) + { + int lastSymStart = outElements.Count; + + while (!pos.IsEmpty) + { + if (pos[0] == '"') // literal string + { + pos = pos.Slice(1); + lastSymStart = outElements.Count; + + while (!pos.IsEmpty && pos[0] != '"') + { + var charPair = ParseChar(ref pos); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair)); + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (pos[0] == '[') // char range(s) + { + pos = pos.Slice(1); + var startType = LLamaGrammarElementType.CHAR; + + if (pos[0] == '^') + { + pos = pos.Slice(1); + startType = LLamaGrammarElementType.CHAR_NOT; + } + + lastSymStart = outElements.Count; + + while (!pos.IsEmpty && pos[0] != ']') + { + var charPair = ParseChar(ref pos); + var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; + + outElements.Add(new LLamaGrammarElement(type, charPair)); + + if (pos[0] == '-' && pos[1] != ']') + { + pos = pos.Slice(1); + var endCharPair = ParseChar(ref pos); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair)); + } + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (IsWordChar(pos[0])) // rule reference + { + var nameEnd = ParseName(pos); + uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); + pos = ParseSpace(nameEnd, isNested); + lastSymStart = outElements.Count; + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId)); + } + else if (pos[0] == '(') // grouping + { + // parse nested alternates into synthesized rule + pos = ParseSpace(pos.Slice(1), true); + uint subRuleId = GenerateSymbolId(state, ruleName); + pos = ParseAlternates(state, pos, ruleName, subRuleId, true); + lastSymStart = outElements.Count; + // output reference to synthesized rule + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); + if (pos[0] != ')') + throw new GrammarExpectedNext(")", Encoding.UTF8.GetString(pos.ToArray())); + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator + { + if (lastSymStart == outElements.Count) + throw new GrammarExpectedPrevious("*/+/?", Encoding.UTF8.GetString(pos.ToArray())); + + // apply transformation to previous symbol (lastSymStart to end) according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? 
--> S' ::= S | + uint subRuleId = GenerateSymbolId(state, ruleName); + + List subRule = new List(); + + // add preceding symbol to generated rule + subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); + + if (pos[0] == '*' || pos[0] == '+') + { + // cause generated rule to recurse + subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); + } + + // mark start of alternate def + subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); + + if (pos[0] == '+') + { + // add preceding symbol as alternate only for '+' (otherwise empty) + subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); + } + + subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); + + AddRule(state, subRuleId, subRule); + + // in original rule, replace previous symbol with reference to generated rule + outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); + + pos = ParseSpace(pos.Slice(1), isNested); + + } + else + { + break; + } + } + + return pos; + } + + private ReadOnlySpan ParseAlternates( + ParseState state, + ReadOnlySpan src, + string ruleName, + uint ruleId, + bool isNested) + { + var rule = new List(); + ReadOnlySpan pos = ParseSequence(state, src, ruleName, rule, isNested); + + while (!pos.IsEmpty && pos[0] == '|') + { + rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); + pos = ParseSpace(pos.Slice(1), true); + pos = ParseSequence(state, pos, ruleName, rule, isNested); + } + + rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); + AddRule(state, ruleId, rule); + + return pos; + } + + private ReadOnlySpan ParseRule(ParseState state, ReadOnlySpan src) + { + ReadOnlySpan nameEnd = ParseName(src); + ReadOnlySpan pos = ParseSpace(nameEnd, false); + int nameLen = src.Length - nameEnd.Length; + uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0); + string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray()); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) + throw new GrammarExpectedNext("::=", Encoding.UTF8.GetString(pos.ToArray())); + + pos = ParseSpace(pos.Slice(3), true); + + pos = ParseAlternates(state, pos, name, ruleId, false); + + if (!pos.IsEmpty && pos[0] == '\r') + { + pos = pos.Slice(pos[1] == '\n' ? 
2 : 1); + } + else if (!pos.IsEmpty && pos[0] == '\n') + { + pos = pos.Slice(1); + } + else if (!pos.IsEmpty) + { + throw new GrammarExpectedNext("newline or EOF", Encoding.UTF8.GetString(pos.ToArray())); + } + return ParseSpace(pos, true); + } + + /// + /// Parse a string of GGML BNF + /// + /// The string to parse + /// The name of the root rule of this grammar + /// Thrown if input is malformed + /// A ParseState that can be converted into a grammar for sampling + public Grammar Parse(string input, string startRule) + { + var byteArray = Encoding.UTF8.GetBytes(input); + var state = new ParseState(); + var pos = ParseSpace(byteArray, true); + + while (!pos.IsEmpty) + { + pos = ParseRule(state, pos); + } + + var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key); + var rules = new List(); + for (var i = 0; i < state.Rules.Count; i++) + { + var elements = state.Rules[i]; + var name = names[(uint)i]; + rules.Add(new GrammarRule(name, elements)); + } + + var startRuleIndex = state.SymbolIds[startRule]; + return new Grammar(rules, startRuleIndex); + } + + private record ParseState + { + public SortedDictionary SymbolIds { get; } = new(); + public List> Rules { get; } = new(); + } + } +} diff --git a/LLama/Grammars/Grammar.cs b/LLama/Grammars/Grammar.cs new file mode 100644 index 00000000..dbb3658e --- /dev/null +++ b/LLama/Grammars/Grammar.cs @@ -0,0 +1,148 @@ +using System; +using System.Collections.Generic; +using System.Text; +using LLama.Exceptions; +using LLama.Native; + +namespace LLama.Grammars +{ + /// + /// A grammar is a set of s for deciding which characters are valid next. Can be used to constrain + /// output to certain formats - e.g. force the model to output JSON + /// + public sealed class Grammar + { + /// + /// Index of the initial rule to start from + /// + public ulong StartRuleIndex { get; set; } + + /// + /// The rules which make up this grammar + /// + public IReadOnlyList Rules { get; } + + /// + /// Create a new grammar from a set of rules + /// + /// The rules which make up this grammar + /// Index of the initial rule to start from + /// + public Grammar(IReadOnlyList rules, ulong startRuleIndex) + { + if (startRuleIndex >= (uint)rules.Count) + throw new ArgumentOutOfRangeException(nameof(startRuleIndex), "startRule must be less than the number of rules"); + + StartRuleIndex = startRuleIndex; + Rules = rules; + } + + /// + /// Create a `SafeLLamaGrammarHandle` instance to use for parsing + /// + /// + public SafeLLamaGrammarHandle CreateInstance() + { + return SafeLLamaGrammarHandle.Create(Rules, StartRuleIndex); + } + + /// + /// Parse a string of GGML BNF into a Grammar + /// + /// The string to parse + /// Name of the start rule of this grammar + /// Thrown if input is malformed + /// A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling + public static Grammar Parse(string gbnf, string startRule) + { + var parser = new GBNFGrammarParser(); + return parser.Parse(gbnf, startRule); + } + + /// + public override string ToString() + { + var builder = new StringBuilder(); + PrintGrammar(builder); + return builder.ToString(); + } + + private void PrintGrammar(StringBuilder output) + { + for (var i = 0; i < Rules.Count; i++) + PrintRule(output, Rules[i]); + } + + private void PrintRule(StringBuilder output, GrammarRule rule) + { + output.Append($"{rule.Name} ::= "); + + for (int i = 0, end = rule.Elements.Count - 1; i < end; i++) + { + var elem = rule.Elements[i]; + switch (elem.Type) + { + // GrammarRule has already verified that END is 
not being misused, no need to check again + case LLamaGrammarElementType.END: + break; + + case LLamaGrammarElementType.ALT: + output.Append("| "); + break; + + case LLamaGrammarElementType.RULE_REF: + output.Append($"{Rules[(int)elem.Value].Name} "); + break; + + case LLamaGrammarElementType.CHAR: + output.Append('['); + PrintGrammarChar(output, elem.Value); + break; + + case LLamaGrammarElementType.CHAR_NOT: + output.Append("[^"); + PrintGrammarChar(output, elem.Value); + break; + + case LLamaGrammarElementType.CHAR_RNG_UPPER: + output.Append('-'); + PrintGrammarChar(output, elem.Value); + break; + + case LLamaGrammarElementType.CHAR_ALT: + PrintGrammarChar(output, elem.Value); + break; + + } + + if (elem.IsCharElement()) + { + switch (rule.Elements[i + 1].Type) + { + case LLamaGrammarElementType.CHAR_ALT: + case LLamaGrammarElementType.CHAR_RNG_UPPER: + break; + default: + output.Append("] "); + break; + } + } + } + + output.AppendLine(); + } + + private static void PrintGrammarChar(StringBuilder output, uint c) + { + if (c >= 0x20 && c <= 0x7F) + { + output.Append((char)c); + } + else + { + // cop out of encoding UTF-8 + output.Append($""); + } + } + } +} diff --git a/LLama/Grammars/GrammarRule.cs b/LLama/Grammars/GrammarRule.cs new file mode 100644 index 00000000..c57b7084 --- /dev/null +++ b/LLama/Grammars/GrammarRule.cs @@ -0,0 +1,74 @@ +using System; +using System.Collections.Generic; +using LLama.Exceptions; +using LLama.Native; + +namespace LLama.Grammars +{ + /// + /// A single rule in a + /// + public sealed record GrammarRule + { + /// + /// Name of this rule + /// + public string Name { get; } + + /// + /// The elements of this grammar rule + /// + public IReadOnlyList Elements { get; } + + /// + /// Create a new GrammarRule containing the given elements + /// + /// + /// + /// + public GrammarRule(string name, IReadOnlyList elements) + { + Validate(elements, name); + + Name = name; + Elements = elements; + } + + private static void Validate(IReadOnlyList elements, string name) + { + if (elements.Count == 0) + throw new ArgumentException("Cannot create a GrammarRule with zero elements", nameof(elements)); + if (elements[elements.Count - 1].Type != LLamaGrammarElementType.END) + throw new ArgumentException("Last grammar element must be END", nameof(elements)); + + for (var i = 0; i < elements.Count; i++) + { + switch (elements[i].Type) + { + case LLamaGrammarElementType.END: + if (i != elements.Count - 1) + throw new GrammarUnexpectedEndElement(name, i); + continue; + + case LLamaGrammarElementType.CHAR_RNG_UPPER: + if (i == 0 || !elements[i - 1].IsCharElement()) + throw new GrammarUnexpectedCharRngElement(name, i); + break; + case LLamaGrammarElementType.CHAR_ALT: + if (i == 0 || !elements[i - 1].IsCharElement()) + throw new GrammarUnexpectedCharAltElement(name, i); + break; + + case LLamaGrammarElementType.ALT: + case LLamaGrammarElementType.RULE_REF: + case LLamaGrammarElementType.CHAR: + case LLamaGrammarElementType.CHAR_NOT: + break; + + default: + throw new ArgumentException($"Unknown grammar element type: '{elements[i].Type}'", nameof(elements)); + } + } + } + } +} diff --git a/LLama/LLamaModel.cs b/LLama/LLamaContext.cs similarity index 61% rename from LLama/LLamaModel.cs rename to LLama/LLamaContext.cs index ccdc2678..3567feca 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaContext.cs @@ -9,34 +9,48 @@ using System.IO.MemoryMappedFiles; using LLama.Common; using System.Runtime.InteropServices; using LLama.Extensions; -using Microsoft.Win32.SafeHandles; using 
LLama.Abstractions; namespace LLama { using llama_token = Int32; + /// - /// The abstraction of a LLama model, which holds the context in the native library. + /// A llama_context, which holds all the context required to interact with a model /// - public class LLamaModel: IDisposable + public sealed class LLamaContext + : IDisposable { - // TODO: expose more properties. - ILLamaLogger? _logger; - Encoding _encoding; - SafeLLamaContextHandle _ctx; + private readonly ILLamaLogger? _logger; + private readonly Encoding _encoding; + private readonly SafeLLamaContextHandle _ctx; + + /// + /// Total number of tokens in vocabulary of this model + /// + public int VocabCount => _ctx.VocabCount; + + /// + /// Total number of tokens in the context + /// + public int ContextSize => _ctx.ContextSize; + /// - /// The context size. + /// Dimension of embedding vectors /// - public int ContextSize { get; } + public int EmbeddingSize => _ctx.EmbeddingSize; + /// /// The model params set for this model. /// public IModelParams Params { get; set; } + /// - /// The native handle, which is used to be passed to the native APIs. Please avoid using it - /// unless you know what is the usage of the Native API. + /// The native handle, which is used to be passed to the native APIs /// + /// Be careful how you use this! public SafeLLamaContextHandle NativeHandle => _ctx; + /// /// The encoding set for this model to deal with text input. /// @@ -59,17 +73,59 @@ namespace LLama /// /// /// - /// Model params. - /// Encoding to deal with text input. + /// Model params. /// The logger. - public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null) + [Obsolete("Use the LLamaWeights.CreateContext instead")] + public LLamaContext(IModelParams @params, ILLamaLogger? logger = null) + { + Params = @params; + + _logger = logger; + _encoding = @params.Encoding; + + _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); + _ctx = Utils.InitLLamaContextFromModelParams(Params); + } + + internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILLamaLogger? logger = null) { + Params = @params; + _logger = logger; - this.Params = Params; - _encoding = Encoding.GetEncoding(encoding); - _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); - _ctx = Utils.InitLLamaContextFromModelParams(this.Params); - ContextSize = NativeApi.llama_n_ctx(_ctx); + _encoding = @params.Encoding; + _ctx = nativeContext; + } + + /// + /// Create a new LLamaContext for the given LLamaWeights + /// + /// + /// + /// + /// + public LLamaContext(LLamaWeights model, IModelParams @params, ILLamaLogger? 
logger = null) + { + if (model.NativeHandle.IsClosed) + throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); + + Params = @params; + + _logger = logger; + _encoding = @params.Encoding; + + using var pin = @params.ToLlamaContextParams(out var lparams); + _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); + } + + /// + /// Create a copy of the current state of this context + /// + /// + public LLamaContext Clone() + { + using var pin = Params.ToLlamaContextParams(out var lparams); + var clone = _ctx.Clone(lparams); + return new LLamaContext(clone, Params); } /// @@ -90,9 +146,10 @@ namespace LLama /// public string DeTokenize(IEnumerable tokens) { - StringBuilder sb = new(); - foreach(var token in tokens) - sb.Append(_ctx.TokenToString(token, _encoding)); + var sb = new StringBuilder(); + foreach (var token in tokens) + _ctx.TokenToString(token, _encoding, sb); + return sb.ToString(); } @@ -147,7 +204,7 @@ namespace LLama /// public State GetState() { - var stateSize = NativeApi.llama_get_state_size(_ctx); + var stateSize = _ctx.GetStateSize(); unsafe { @@ -156,15 +213,17 @@ namespace LLama try { // Copy the state data into "big memory", discover the actual size required - var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory); + var actualSize = _ctx.GetState(bigMemory, stateSize); + + // if big memory is nearly completely full (within 1MB) early exit and skip the extra copying + if (actualSize >= stateSize - 1_000_000) + return new State(bigMemory); - // Allocate a smaller buffer + // Allocate a smaller buffer which is exactly the right size smallMemory = Marshal.AllocHGlobal((nint)actualSize); // Copy into the smaller buffer and free the large one to save excess memory usage Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); - Marshal.FreeHGlobal(bigMemory); - bigMemory = IntPtr.Zero; return new State(smallMemory); } @@ -224,7 +283,7 @@ namespace LLama { unsafe { - NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer()); + _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer()); } } @@ -241,11 +300,19 @@ namespace LLama /// /// /// + /// /// public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, - float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) + float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f, + SafeLLamaGrammarHandle? 
grammar = null) { llama_token id; + + if (grammar != null) + { + SamplingApi.llama_sample_grammar(_ctx, candidates, grammar); + } + if (temperature <= 0) { // Greedy sampling @@ -279,6 +346,12 @@ namespace LLama } mirostat_mu = mu; } + + if (grammar != null) + { + NativeApi.llama_grammar_accept_token(_ctx, grammar, id); + } + return id; } @@ -297,41 +370,89 @@ namespace LLama int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, bool penalizeNL = true) { - var n_vocab = _ctx.VocabCount; var logits = _ctx.GetLogits(); // Apply params.logit_bias map - if(logitBias is not null) + if (logitBias is not null) { foreach (var (key, value) in logitBias) - { logits[key] += value; - } } - var candidates = new LLamaTokenData[n_vocab]; - for (llama_token token_id = 0; token_id < n_vocab; token_id++) - candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); - LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); - - // Apply penalties - float nl_logit = logits[NativeApi.llama_token_nl()]; - int lastTokensCount = lastTokens.Count(); - var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize); - SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, - lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), - (ulong)last_n_repeat, repeatPenalty); - SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, - lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), - (ulong)last_n_repeat, alphaFrequency, alphaPresence); + // Save the newline logit value + var nl_token = NativeApi.llama_token_nl(_ctx); + var nl_logit = logits[nl_token]; + + // Convert logits into token candidates + var candidates_p = LLamaTokenDataArray.Create(logits); + + // Extract most recently returned tokens + var last_n_repeat = Math.Min(ContextSize, repeatLastTokensCount); + var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); + + // Apply penalties to candidates + SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty); + SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence); + + // Restore newline token logit value if necessary if (!penalizeNL) { - logits[NativeApi.llama_token_nl()] = nl_logit; + var candidatesSpan = candidates_p.data.Span; + for (var i = 0; i < candidates_p.data.Length; i++) + { + ref var item = ref candidatesSpan[i]; + if (item.id == nl_token) + item.logit = nl_logit; + } + candidates_p.sorted = false; } return candidates_p; } + #region eval overloads + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(llama_token[] tokens, llama_token pastTokensCount) + { + return Eval(tokens.AsSpan(), pastTokensCount); + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(List tokens, llama_token pastTokensCount) + { +#if NET5_0_OR_GREATER + var span = CollectionsMarshal.AsSpan(tokens); + return Eval(span, pastTokensCount); +#else + // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of + // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't + // avoid the copying. 
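
The optional `grammar` parameter added to `Sample` above threads grammar-constrained decoding through the normal sampling path: candidates are filtered with `llama_sample_grammar` before a token is chosen, and the chosen token is reported back with `llama_grammar_accept_token`. The sketch below shows one such sampling step; the helper class, the hard-coded penalty values, and the assumption that `ApplyPenalty`'s two leading parameters are the recent tokens and an optional logit-bias map (matching the call sites elsewhere in this diff) are illustrative, not part of the PR.

```csharp
using System.Collections.Generic;
using LLama;
using LLama.Native;

static class GrammarSamplingSketch
{
    // One grammar-constrained sampling step. When `grammar` is non-null, Sample()
    // applies llama_sample_grammar to the candidates before picking a token and
    // calls llama_grammar_accept_token afterwards, so the caller only passes the
    // handle through.
    public static int SampleNext(LLamaContext context, List<int> lastTokens, SafeLLamaGrammarHandle? grammar)
    {
        float? mirostatMu = null;

        // Build penalised candidates from the current logits
        // (same call shape the executors use).
        var candidates = context.ApplyPenalty(
            lastTokens,
            null,    // no logit bias
            64,      // repeatLastTokensCount
            1.1f,    // repeatPenalty
            0.0f,    // alphaFrequency
            0.0f,    // alphaPresence
            true);   // penalizeNL

        return context.Sample(candidates, ref mirostatMu, temperature: 0.8f, grammar: grammar);
    }
}
```
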
+ + var rented = System.Buffers.ArrayPool.Shared.Rent(tokens.Count); + try + { + tokens.CopyTo(rented, 0); + return Eval(rented, pastTokensCount); + } + finally + { + System.Buffers.ArrayPool.Shared.Return(rented); + } +#endif + } + /// /// /// @@ -339,20 +460,32 @@ namespace LLama /// /// The updated `pastTokensCount`. /// - public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) + public int Eval(ReadOnlyMemory tokens, llama_token pastTokensCount) { - int total = tokens.Length; - for(int i = 0; i < total; i += Params.BatchSize) + return Eval(tokens.Span, pastTokensCount); + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(ReadOnlySpan tokens, llama_token pastTokensCount) + { + var total = tokens.Length; + for(var i = 0; i < total; i += Params.BatchSize) { - int n_eval = total - i; - if(n_eval > Params.BatchSize) + var n_eval = total - i; + if (n_eval > Params.BatchSize) { n_eval = Params.BatchSize; } - if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads)) + if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads)) { - _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error); + _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error); throw new RuntimeError("Failed to eval."); } @@ -360,16 +493,26 @@ namespace LLama } return pastTokensCount; } +#endregion - // TODO: add comment internal IEnumerable GenerateResult(IEnumerable ids) { foreach(var id in ids) yield return _ctx.TokenToString(id, _encoding); } + /// + /// Convert a token into a string + /// + /// + /// + public string TokenToString(llama_token token) + { + return NativeHandle.TokenToString(token, Encoding); + } + /// - public virtual void Dispose() + public void Dispose() { _ctx.Dispose(); } @@ -378,12 +521,11 @@ namespace LLama /// The state of this model, which can be reloaded later /// public class State - : SafeHandleZeroOrMinusOneIsInvalid + : SafeLLamaHandleBase { internal State(IntPtr memory) - : base(true) + : base(memory) { - SetHandle(memory); } /// diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index a74f11ee..64c17539 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -1,9 +1,6 @@ using LLama.Native; using System; -using System.Collections.Generic; -using System.Text; using LLama.Exceptions; -using System.Linq; using LLama.Abstractions; namespace LLama @@ -11,18 +8,15 @@ namespace LLama /// /// The embedder for LLama, which supports getting embeddings from text. /// - public class LLamaEmbedder : IDisposable + public sealed class LLamaEmbedder + : IDisposable { - SafeLLamaContextHandle _ctx; + private readonly LLamaContext _ctx; /// - /// Warning: must ensure the original model has params.embedding = true; + /// Dimension of embedding vectors /// - /// - internal LLamaEmbedder(SafeLLamaContextHandle ctx) - { - _ctx = ctx; - } + public int EmbeddingSize => _ctx.EmbeddingSize; /// /// @@ -31,52 +25,67 @@ namespace LLama public LLamaEmbedder(IModelParams @params) { @params.EmbeddingMode = true; - _ctx = Utils.InitLLamaContextFromModelParams(@params); + using var weights = LLamaWeights.LoadFromFile(@params); + _ctx = weights.CreateContext(@params); + } + + public LLamaEmbedder(LLamaWeights weights, IModelParams @params) + { + _ctx = weights.CreateContext(@params); } /// /// Get the embeddings of the text. /// /// - /// Threads used for inference. + /// unused /// Add bos to the text. 
- /// + /// unused /// /// - public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") + [Obsolete("'threads' and 'encoding' parameters are no longer used")] + // ReSharper disable once MethodOverloadWithOptionalParameter + public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") { - if (threads == -1) - { - threads = Math.Max(Environment.ProcessorCount / 2, 1); - } - int n_past = 0; - if (addBos) - { - text = text.Insert(0, " "); - } + return GetEmbeddings(text, addBos); + } - var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding)); + /// + /// Get the embeddings of the text. + /// + /// + /// + /// + public float[] GetEmbeddings(string text) + { + return GetEmbeddings(text, true); + } + + /// + /// Get the embeddings of the text. + /// + /// + /// Add bos to the text. + /// + /// + public float[] GetEmbeddings(string text, bool addBos) + { + + var embed_inp_array = _ctx.Tokenize(text, addBos); // TODO(Rinne): deal with log of prompt if (embed_inp_array.Length > 0) - { - if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, threads) != 0) - { - throw new RuntimeError("Failed to eval."); - } - } + _ctx.Eval(embed_inp_array, 0); - int n_embed = NativeApi.llama_n_embd(_ctx); - var embeddings = NativeApi.llama_get_embeddings(_ctx); - if (embeddings == null) + unsafe { - return Array.Empty(); + var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle); + if (embeddings == null) + return Array.Empty(); + + return new Span(embeddings, EmbeddingSize).ToArray(); } - var span = new Span(embeddings, n_embed); - float[] res = new float[n_embed]; - span.CopyTo(res.AsSpan()); - return res; } /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 2caaa8e5..73dd439c 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -18,10 +18,6 @@ namespace LLama /// public abstract class StatefulExecutorBase : ILLamaExecutor { - /// - /// The loaded model for this executor. - /// - protected readonly LLamaModel _model; /// /// The logger used by this executor. /// @@ -63,9 +59,9 @@ namespace LLama /// protected FixedSizeQueue _last_n_tokens; /// - /// The mode used by the executor. + /// The context used by the executor. /// - public LLamaModel Model => _model; + public LLamaContext Context { get; } /// /// Current "mu" value for mirostat sampling @@ -75,16 +71,16 @@ namespace LLama /// /// /// - /// + /// /// - protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null) + protected StatefulExecutorBase(LLamaContext context, ILLamaLogger? 
logger = null) { - _model = model; + Context = context; _logger = logger; _pastTokensCount = 0; _consumedTokensCount = 0; _n_session_consumed = 0; - _last_n_tokens = new FixedSizeQueue(_model.ContextSize).FillWith(0); + _last_n_tokens = new FixedSizeQueue(Context.ContextSize).FillWith(0); } /// @@ -104,9 +100,9 @@ namespace LLama if (File.Exists(filename)) { _logger?.Log("LLamaExecutor", $"Attempting to load saved session from {filename}", ILLamaLogger.LogLevel.Info); - llama_token[] session_tokens = new llama_token[_model.ContextSize]; + llama_token[] session_tokens = new llama_token[Context.ContextSize]; ulong n_token_count_out = 0; - if (!NativeApi.llama_load_session_file(_model.NativeHandle, _pathSession, session_tokens, (ulong)_model.ContextSize, &n_token_count_out)) + if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, &n_token_count_out)) { _logger?.Log("LLamaExecutor", $"Failed to load session file {filename}", ILLamaLogger.LogLevel.Error); throw new RuntimeError($"Failed to load session file {_pathSession}"); @@ -156,7 +152,7 @@ namespace LLama public void SaveSessionFile(string filename) { var session_token_array = _session_tokens.ToArray(); - NativeApi.llama_save_session_file(_model.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); + NativeApi.llama_save_session_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); } /// @@ -173,7 +169,7 @@ namespace LLama _pastTokensCount = Math.Max(1, tokensToKeep); // insert n_left/2 tokens at the start of embed from last_n_tokens - _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(_model.ContextSize - n_left / 2 - _embeds.Count)); + _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(Context.ContextSize - n_left / 2 - _embeds.Count)); // stop saving session if we run out of context _pathSession = string.Empty; @@ -270,10 +266,7 @@ namespace LLama public virtual IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - if (inferenceParams is null) - { - inferenceParams = new InferenceParams(); - } + inferenceParams ??= new InferenceParams(); InferStateArgs args = new InferStateArgs() { @@ -296,7 +289,7 @@ namespace LLama if (args.ReturnValue) { - foreach (var item in _model.GenerateResult(_embeds)) + foreach (var item in Context.GenerateResult(_embeds)) { yield return item; } @@ -374,7 +367,7 @@ namespace LLama public int MatchingSessionTokensCount { get; set; } [JsonPropertyName("path_session")] - public string SessionFilePath { get; set; } + public string? 
SessionFilePath { get; set; } [JsonPropertyName("embd")] public List Embeds { get; set; } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 7a065ce5..a7d53cc8 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -24,14 +25,14 @@ namespace LLama /// /// /// - /// + /// /// /// - public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n", - string instructionSuffix = "\n\n### Response:\n\n") : base(model) + public InstructExecutor(LLamaContext context, string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n") : base(context) { - _inp_pfx = _model.Tokenize(instructionPrefix, true); - _inp_sfx = _model.Tokenize(instructionSuffix, false); + _inp_pfx = Context.Tokenize(instructionPrefix, true); + _inp_sfx = Context.Tokenize(instructionSuffix, false); _instructionPrefix = instructionPrefix; } @@ -84,16 +85,16 @@ namespace LLama /// public override void SaveState(string filename) { - InstructExecutorState state = GetStateData() as InstructExecutorState; - using (FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write)) + var state = (InstructExecutorState)GetStateData(); + using (var fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write)) { - JsonSerializer.Serialize(fs, state); + JsonSerializer.Serialize(fs, state); } } /// public override void LoadState(string filename) { - using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read)) + using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read)) { var state = JsonSerializer.Deserialize(fs); LoadState(state); @@ -108,16 +109,12 @@ namespace LLama /// protected override void PreprocessInputs(string text, InferStateArgs args) { - if(args.Antiprompts is null) - { - args.Antiprompts = new List(); - } + args.Antiprompts ??= new List(); args.Antiprompts.Add(_instructionPrefix); if (_is_prompt_run) { // When running the first input (prompt) in inteactive mode, we should specially process it. 
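
Since the executors now wrap an `LLamaContext` rather than the old `LLamaModel`, construction goes weights → context → executor. A minimal usage sketch, assuming the usual `ModelParams(string modelPath)` constructor from `LLama.Common`; the model path and prompt are placeholders.

```csharp
using System;
using LLama;
using LLama.Common;

// Placeholder model path; any IModelParams implementation works here.
var @params = new ModelParams("path/to/model.bin");

using var weights = LLamaWeights.LoadFromFile(@params);
using var context = weights.CreateContext(@params);

// The instruct executor tokenizes its instruction prefix/suffix with the context it is given.
var executor = new InstructExecutor(context);

foreach (var piece in executor.Infer("Write one sentence about llamas."))
    Console.Write(piece);
```
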
- text = " " + text; - _embed_inps = _model.Tokenize(text, true).ToList(); + _embed_inps = Context.Tokenize(text, true).ToList(); } else { @@ -128,7 +125,7 @@ namespace LLama _consumedTokensCount = _embed_inps.Count; _embed_inps.AddRange(_inp_pfx); - var line_inp = _model.Tokenize(text, false); + var line_inp = Context.Tokenize(text, false); _embed_inps.AddRange(line_inp); _embed_inps.AddRange(_inp_sfx); @@ -144,9 +141,10 @@ namespace LLama { if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - string last_output = ""; - foreach (var id in _last_n_tokens) - last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); + var last_output_builder = new StringBuilder(); + foreach (var token in _last_n_tokens) + Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder); + var last_output = last_output_builder.ToString(); foreach (var antiprompt in args.Antiprompts) { @@ -160,12 +158,12 @@ namespace LLama if (_pastTokensCount > 0 && args.WaitForInput) { - extraOutputs = new string[] { "\n> " }; + extraOutputs = new[] { "\n> " }; return true; } } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) + if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) { args.WaitForInput = true; } @@ -183,13 +181,13 @@ namespace LLama if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + if (_pastTokensCount + _embeds.Count > Context.ContextSize) { HandleRunOutOfContext(inferenceParams.TokensKeep); } TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { @@ -202,7 +200,7 @@ namespace LLama if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? 
Context.ContextSize : inferenceParams.RepeatLastTokensCount; // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) @@ -211,13 +209,14 @@ namespace LLama SaveSessionFile(_pathSession); } - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); var mu = MirostatMu; - var id = _model.Sample( + var id = Context.Sample( tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, + inferenceParams.Grammar ); MirostatMu = mu; @@ -235,7 +234,7 @@ namespace LLama _embeds.Add(_embed_inps[_consumedTokensCount]); _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); _consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) + if (_embeds.Count >= Context.Params.BatchSize) { break; } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 52d8d3bc..38d6b443 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -7,6 +7,7 @@ using System.IO; using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; +using System.Text; namespace LLama { @@ -22,10 +23,10 @@ namespace LLama /// /// /// - /// - public InteractiveExecutor(LLamaModel model) : base(model) + /// + public InteractiveExecutor(LLamaContext context) : base(context) { - _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding); + _llama_token_newline = new [] { NativeApi.llama_token_nl(Context.NativeHandle) }; } /// @@ -72,10 +73,10 @@ namespace LLama /// public override void SaveState(string filename) { - InteractiveExecutorState state = GetStateData() as InteractiveExecutorState; + InteractiveExecutorState state = (InteractiveExecutorState)GetStateData(); using(FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write)) { - JsonSerializer.Serialize(fs, state); + JsonSerializer.Serialize(fs, state); } } /// @@ -103,8 +104,7 @@ namespace LLama if (_is_prompt_run) { // When running the first input (prompt) in inteactive mode, we should specially process it. - text = " " + text; - _embed_inps = _model.Tokenize(text, true).ToList(); + _embed_inps = Context.Tokenize(text, true).ToList(); } else { @@ -112,7 +112,7 @@ namespace LLama { text += "\n"; } - var line_inp = _model.Tokenize(text, false); + var line_inp = Context.Tokenize(text, false); _embed_inps.AddRange(line_inp); args.RemainedTokens -= line_inp.Length; } @@ -121,7 +121,9 @@ namespace LLama /// /// Return whether to break the generation. /// + /// /// + /// /// protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable? 
extraOutputs) { @@ -130,11 +132,10 @@ namespace LLama { if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - string last_output = ""; - foreach (var id in _last_n_tokens) - { - last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); - } + var last_output_builder = new StringBuilder(); + foreach (var token in _last_n_tokens) + Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder); + var last_output = last_output_builder.ToString(); foreach (var antiprompt in args.Antiprompts) { @@ -152,9 +153,9 @@ namespace LLama } } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) + if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) { - extraOutputs = new string[] { " [end of text]\n" }; + extraOutputs = new[] { " [end of text]\n" }; return true; } @@ -172,13 +173,13 @@ namespace LLama if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + if (_pastTokensCount + _embeds.Count > Context.ContextSize) { HandleRunOutOfContext(inferenceParams.TokensKeep); } TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { @@ -191,7 +192,7 @@ namespace LLama if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) @@ -200,24 +201,25 @@ namespace LLama SaveSessionFile(_pathSession); } - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); var mu = MirostatMu; - var id = _model.Sample( + var id = Context.Sample( tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, + inferenceParams.Grammar ); MirostatMu = mu; _last_n_tokens.Enqueue(id); - if (id == NativeApi.llama_token_eos()) + if (id == NativeApi.llama_token_eos(Context.NativeHandle)) { id = _llama_token_newline.First(); if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false); + var first_antiprompt = Context.Tokenize(args.Antiprompts[0], false); _embed_inps.AddRange(first_antiprompt); } } @@ -234,7 +236,7 @@ namespace LLama _embeds.Add(_embed_inps[_consumedTokensCount]); _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); _consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) + if (_embeds.Count >= Context.Params.BatchSize) { break; } diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs index c3ff5613..f1d89586 100644 --- 
a/LLama/LLamaQuantizer.cs +++ b/LLama/LLamaQuantizer.cs @@ -1,8 +1,6 @@ using LLama.Native; using System; using System.Collections.Generic; -using System.Linq; -using System.Text; namespace LLama { @@ -36,8 +34,7 @@ namespace LLama quantizeParams.nthread = nthread; quantizeParams.allow_requantize = allowRequantize; quantizeParams.quantize_output_tensor = quantizeOutputTensor; - LLamaModelQuantizeParams* p = &quantizeParams; - return NativeApi.llama_model_quantize(srcFileName, dstFilename, p) == 0; + return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0; } /// @@ -57,42 +54,71 @@ namespace LLama return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread, allowRequantize, quantizeOutputTensor); } - private static bool ValidateFtype(string ftype) - { - return new string[] { "q4_0", "q4_1", "q5_0", "q5_1", "q8_0" }.Contains(ftype); - } - private static bool ValidateFtype(LLamaFtype ftype) { - return ftype is LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1 - or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0; - } + // Validation copies from here: + // https://github.com/ggerganov/llama.cpp/blob/e59fcb2bc129881f4a269fee748fb38bce0a64de/llama.cpp#L2960 - private static string FtypeToString(LLamaFtype ftype) - { - return ftype switch + switch (ftype) { - LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0 => "q4_0", - LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1 => "q4_1", - LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0 => "q5_0", - LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1 => "q5_1", - LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0 => "q8_0", - _ => throw new ArgumentException($"The type {Enum.GetName(typeof(LLamaFtype), ftype)} is not a valid type " + - $"to perform quantization.") - }; + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_F16: + case LLamaFtype.LLAMA_FTYPE_ALL_F32: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_S: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_M: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_L: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_S: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_M: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_S: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_M: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q6_K: + return true; + + case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: + default: + return false; + } } + /// + /// Parse a string into a LLamaFtype. This is a "relaxed" parsing, which allows any string which is contained within + /// the enum name to be used. 
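
The quantizer now validates against the full set of supported `LLamaFtype` values and resolves ftype strings by relaxed, case-insensitive name matching (described just above and implemented below). A hypothetical call, with placeholder file paths and every argument spelled out positionally in the order used by the string overload shown in this diff:

```csharp
using System;
using LLama;

// "Q5_K_M" resolves to LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_M via the relaxed lookup.
var ok = LLamaQuantizer.Quantize(
    "ggml-model-f16.bin",     // source model (placeholder)
    "ggml-model-q5_k_m.bin",  // destination (placeholder)
    "Q5_K_M",                 // relaxed ftype string
    -1,                       // nthread: <= 0 uses hardware concurrency
    true,                     // allowRequantize
    false);                   // quantizeOutputTensor

Console.WriteLine(ok ? "Quantization succeeded" : "Quantization failed");
```
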
+ /// + /// For example "Q5_K_M" will convert to "LLAMA_FTYPE_MOSTLY_Q5_K_M" + /// + /// + /// + /// private static LLamaFtype StringToFtype(string str) { - return str switch + // Find all variants which contain the input string + var matches = new List(); + foreach (LLamaFtype ftype in Enum.GetValues(typeof(LLamaFtype))) { - "q4_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0, - "q4_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1, - "q5_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0, - "q5_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1, - "q8_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0, - _ => throw new ArgumentException($"Invalid ftype {str} to quantize.") - }; + var name = Enum.GetName(typeof(LLamaFtype), ftype); + + // Note: this is using "IndexOf" instead of "Contains" to be compatible with netstandard2.0 +#pragma warning disable CA2249 + if (name != null && name.IndexOf(str, StringComparison.OrdinalIgnoreCase) >= 0) + matches.Add(ftype); +#pragma warning restore CA2249 + } + + // If there was just one match, success! + if (matches.Count == 1) + return matches[0]; + + // If none matched throw a generic error + if (matches.Count == 0) + throw new ArgumentException($"Unknown ftype \"{str}\" for quantization."); + + // There were several matches, throw an error asking the user to be more specific + throw new ArgumentException($"\"{str}\" matches multiple potential ftypes: {string.Join(",", matches)}"); } } } diff --git a/LLama/LLamaSharp.Runtime.targets b/LLama/LLamaSharp.Runtime.targets index e83b11ac..df079ba3 100644 --- a/LLama/LLamaSharp.Runtime.targets +++ b/LLama/LLamaSharp.Runtime.targets @@ -32,11 +32,11 @@ libllama.dylib - PreserveNewest + None libllama-metal.dylib - PreserveNewest + None ggml-metal.metal diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index be32a5af..d3f0c0e2 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -1,125 +1,159 @@ using LLama.Abstractions; using LLama.Common; -using LLama.Native; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; namespace LLama { using llama_token = Int32; + /// /// This executor infer the input as one-time job. Previous inputs won't impact on the /// response to current input. /// - public class StatelessExecutor : ILLamaExecutor + public class StatelessExecutor + : ILLamaExecutor { - private LLamaModel _model; - private LLamaModel.State _originalState; + private readonly LLamaWeights _weights; + private readonly IModelParams _params; + + /// + /// The context used by the executor when running the inference. + /// + public LLamaContext Context { get; private set; } + /// - /// The mode used by the executor when running the inference. + /// Create a new stateless executor which will use the given model /// - public LLamaModel Model => _model; + /// + /// + public StatelessExecutor(LLamaWeights weights, IModelParams @params) + { + _weights = weights; + _params = @params; + + Context = _weights.CreateContext(_params); + Context.Dispose(); + } + /// - /// + /// Create a new stateless executor which will use the model used to create the given context /// - /// The LLama model. 
- public StatelessExecutor(LLamaModel model) + /// + [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] + public StatelessExecutor(LLamaContext context) { - _model = model; - - var tokens = model.Tokenize(" ", true).ToArray(); - _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads); - _originalState = model.GetState(); + _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding); + _params = context.Params; + + Context = _weights.CreateContext(_params); + Context.Dispose(); } /// public IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { - cancellationToken.ThrowIfCancellationRequested(); - int n_past = 1; - if(inferenceParams is null) - { - inferenceParams = new InferenceParams(); - } - List lastTokens = new(inferenceParams.RepeatLastTokensCount); - for(int i = 0; i < lastTokens.Count; i++) + using var context = _weights.CreateContext(_params); + Context = context; + + if (!Context.NativeHandle.IsClosed) + Context.Dispose(); + Context = _weights.CreateContext(Context.Params); + + if (inferenceParams != null) { - lastTokens[i] = 0; + if (inferenceParams.TokensKeep > Context.ContextSize) + throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); } - List tokens = _model.Tokenize(text, true).ToList(); - int n_prompt_tokens = tokens.Count; - _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads); + cancellationToken.ThrowIfCancellationRequested(); + + var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty(); + var n_past = 1; + inferenceParams ??= new InferenceParams(); + + var lastTokens = new List(inferenceParams.RepeatLastTokensCount); + for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++) + lastTokens.Add(0); + + var tokens = Context.Tokenize(text).ToList(); + var n_prompt_tokens = tokens.Count; + + Context.Eval(tokens, n_past); lastTokens.AddRange(tokens); n_past += n_prompt_tokens; var mu = (float?)null; - int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; - for(int i = 0; i < max_tokens; i++) + var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; + for(var i = 0; i < max_tokens; i++) { if (cancellationToken.IsCancellationRequested) - { - _model.LoadState(_originalState); break; - } - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; - var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? 
Context.ContextSize : inferenceParams.RepeatLastTokensCount; + + var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP); + var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); lastTokens.Add(id); - string response = _model.NativeHandle.TokenToString(id, _model.Encoding); + var response = Context.TokenToString(id); yield return response; tokens.Clear(); tokens.Add(id); - if (inferenceParams.AntiPrompts is not null && inferenceParams.AntiPrompts.Count() > 0) - { - string last_output = ""; - foreach (var token in lastTokens) - { - last_output += _model.NativeHandle.TokenToString(token, _model.Encoding); - } - - bool should_break = false; - foreach (var antiprompt in inferenceParams.AntiPrompts) - { - if (last_output.EndsWith(antiprompt)) - { - should_break = true; - break; - } - } - if (should_break) - { - break; - } - } + if (EndsWithAntiprompt(lastTokens, antiprompts)) + break; // when run out of context - if (n_past + tokens.Count > _model.ContextSize) + // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 + if (n_past + tokens.Count > Context.ContextSize) { - int n_left = n_past - inferenceParams.TokensKeep; + var n_left = n_past - inferenceParams.TokensKeep; n_past = Math.Max(1, inferenceParams.TokensKeep); - // insert n_left/2 tokens at the start of embed from last_n_tokens - tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_model.ContextSize - n_left / 2 - tokens.Count)); + tokens.Clear(); + tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); } - n_past = _model.Eval(tokens.ToArray(), n_past); + n_past = Context.Eval(tokens, n_past); + } + } + + /// + /// Check if the given tokens list ends with any of the antiprompts + /// + /// + /// + /// + private bool EndsWithAntiprompt(IReadOnlyList tokens, IReadOnlyList antiprompts) + { + if (antiprompts.Count == 0 || tokens.Count == 0) + return false; + + var builder = new StringBuilder(); + foreach (var token in tokens) + builder.Append(Context.TokenToString(token)); + + var last_output = builder.ToString(); + + foreach (var antiprompt in antiprompts) + { + if (last_output.EndsWith(antiprompt)) + return true; } - _model.LoadState(_originalState); + return false; } /// diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs new file mode 100644 index 00000000..1b067f1b --- /dev/null +++ b/LLama/LLamaWeights.cs @@ -0,0 +1,81 @@ +using System; +using System.Text; +using LLama.Abstractions; +using LLama.Extensions; +using LLama.Native; + +namespace LLama +{ + /// + /// A set of model weights, loaded into memory. + /// + public sealed class LLamaWeights + : IDisposable + { + private readonly SafeLlamaModelHandle _weights; + + /// + /// The native handle, which is used in the native APIs + /// + /// Be careful how you use this! 
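
`LLamaWeights` separates the large, shareable model weights from the cheap per-inference `llama_context`, which is what lets the rewritten `StatelessExecutor` create a fresh context for each call to `Infer`. A usage sketch with placeholder path and prompt; the `ModelParams` constructor and `InferenceParams` property names are assumed from the library's common usage rather than this diff.

```csharp
using System;
using LLama;
using LLama.Common;

var @params = new ModelParams("path/to/model.bin");  // placeholder path

// Weights are loaded once and can be shared by many executors/contexts.
using var weights = LLamaWeights.LoadFromFile(@params);

var executor = new StatelessExecutor(weights, @params);
var inference = new InferenceParams { MaxTokens = 64, AntiPrompts = new[] { "Q:" } };

foreach (var piece in executor.Infer("Q: What is the capital of France?\nA:", inference))
    Console.Write(piece);
```
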
+ public SafeLlamaModelHandle NativeHandle => _weights; + + /// + /// Encoding to use to convert text into bytes for the model + /// + public Encoding Encoding { get; } + + /// + /// Total number of tokens in vocabulary of this model + /// + public int VocabCount => NativeHandle.VocabCount; + + /// + /// Total number of tokens in the context + /// + public int ContextSize => NativeHandle.ContextSize; + + /// + /// Dimension of embedding vectors + /// + public int EmbeddingSize => NativeHandle.EmbeddingSize; + + internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) + { + _weights = weights; + Encoding = encoding; + } + + /// + /// Load weights into memory + /// + /// + /// + public static LLamaWeights LoadFromFile(IModelParams @params) + { + using var pin = @params.ToLlamaContextParams(out var lparams); + var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); + + if (!string.IsNullOrEmpty(@params.LoraAdapter)) + weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + + return new LLamaWeights(weights, @params.Encoding); + } + + /// + public void Dispose() + { + _weights.Dispose(); + } + + /// + /// Create a llama_context using this model + /// + /// + /// + public LLamaContext CreateContext(IModelParams @params) + { + return new LLamaContext(this, @params); + } + } +} diff --git a/LLama/Native/GgmlInitParams.cs b/LLama/Native/GgmlInitParams.cs deleted file mode 100644 index 834ceab9..00000000 --- a/LLama/Native/GgmlInitParams.cs +++ /dev/null @@ -1,15 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; - -namespace LLama.Native -{ - internal struct GgmlInitParams - { - public ulong mem_size; - public IntPtr mem_buffer; - [MarshalAs(UnmanagedType.I1)] - public bool no_alloc; - } -} diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index 0ede4e76..200301da 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -1,11 +1,18 @@ using System; -using System.Collections.Generic; using System.Runtime.InteropServices; -using System.Text; namespace LLama.Native { + /// + /// Called by llama.cpp with a progress value between 0 and 1 + /// + /// + /// public delegate void LlamaProgressCallback(float progress, IntPtr ctx); + + /// + /// A C# representation of the llama.cpp `llama_context_params` struct + /// [StructLayout(LayoutKind.Sequential)] public struct LLamaContextParams { @@ -24,16 +31,6 @@ namespace LLama.Native /// public int n_batch; - /// - /// grouped-query attention (TEMP - will be moved to model hparams) - /// - public int n_gqa; - - /// - /// rms norm epsilon (TEMP - will be moved to model hparams) - /// - public float rms_norm_eps; - /// /// number of layers to store in VRAM /// @@ -49,7 +46,6 @@ namespace LLama.Native /// public nint tensor_split; - /// /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 /// RoPE base frequency @@ -72,53 +68,85 @@ namespace LLama.Native /// public IntPtr progress_callback_user_data; - /// /// if true, reduce VRAM usage at the cost of performance /// - [MarshalAs(UnmanagedType.I1)] - public bool low_vram; + public bool low_vram + { + readonly get => Convert.ToBoolean(_low_vram); + set => _low_vram = Convert.ToSByte(value); + } + private sbyte _low_vram; /// /// if true, use experimental mul_mat_q kernels /// - [MarshalAs(UnmanagedType.I1)] public bool mul_mat_q; + public bool mul_mat_q + { + readonly get => Convert.ToBoolean(_mul_mat_q); + set => 
_mul_mat_q = Convert.ToSByte(value); + } + private sbyte _mul_mat_q; /// /// use fp16 for KV cache /// - [MarshalAs(UnmanagedType.I1)] - public bool f16_kv; + public bool f16_kv + { + readonly get => Convert.ToBoolean(_f16_kv); + set => _f16_kv = Convert.ToSByte(value); + } + private sbyte _f16_kv; /// /// the llama_eval() call computes all logits, not just the last one /// - [MarshalAs(UnmanagedType.I1)] - public bool logits_all; + public bool logits_all + { + readonly get => Convert.ToBoolean(_logits_all); + set => _logits_all = Convert.ToSByte(value); + } + private sbyte _logits_all; /// /// only load the vocabulary, no weights /// - [MarshalAs(UnmanagedType.I1)] - public bool vocab_only; + public bool vocab_only + { + readonly get => Convert.ToBoolean(_vocab_only); + set => _vocab_only = Convert.ToSByte(value); + } + private sbyte _vocab_only; /// /// use mmap if possible /// - [MarshalAs(UnmanagedType.I1)] - public bool use_mmap; + public bool use_mmap + { + readonly get => Convert.ToBoolean(_use_mmap); + set => _use_mmap = Convert.ToSByte(value); + } + private sbyte _use_mmap; /// /// force system to keep model in RAM /// - [MarshalAs(UnmanagedType.I1)] - public bool use_mlock; + public bool use_mlock + { + readonly get => Convert.ToBoolean(_use_mlock); + set => _use_mlock = Convert.ToSByte(value); + } + private sbyte _use_mlock; /// /// embedding mode only /// - [MarshalAs(UnmanagedType.I1)] - public bool embedding; + public bool embedding + { + readonly get => Convert.ToBoolean(_embedding); + set => _embedding = Convert.ToSByte(value); + } + private sbyte _embedding; } } diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs index 41159ee2..0fa0fbe9 100644 --- a/LLama/Native/LLamaFtype.cs +++ b/LLama/Native/LLamaFtype.cs @@ -1,29 +1,114 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace LLama.Native +namespace LLama.Native { + /// + /// Supported model file types + /// public enum LLamaFtype { + /// + /// All f32 + /// + /// Benchmark@7B: 26GB LLAMA_FTYPE_ALL_F32 = 0, - LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed - LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + + /// + /// Mostly f16 + /// + /// Benchmark@7B: 13GB + LLAMA_FTYPE_MOSTLY_F16 = 1, + + /// + /// Mostly 8 bit + /// + /// Benchmark@7B: 6.7GB, +0.0004ppl + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, + + /// + /// Mostly 4 bit + /// + /// Benchmark@7B: 3.50GB, +0.2499 ppl + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, + + /// + /// Mostly 4 bit + /// + /// Benchmark@7B: 3.90GB, +0.1846 ppl + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, + + /// + /// Mostly 4 bit, tok_embeddings.weight and output.weight are f16 + /// + 
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, + + /// + /// Mostly 5 bit + /// + /// Benchmark@7B: 4.30GB @ 7B tokens, +0.0796 ppl + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, + + /// + /// Mostly 5 bit + /// + /// Benchmark@7B: 4.70GB, +0.0415 ppl + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, + + /// + /// K-Quant 2 bit + /// + /// Benchmark@7B: 2.67GB @ 7N parameters, +0.8698 ppl + LLAMA_FTYPE_MOSTLY_Q2_K = 10, + + /// + /// K-Quant 3 bit (Small) + /// + /// Benchmark@7B: 2.75GB, +0.5505 ppl + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, + + /// + /// K-Quant 3 bit (Medium) + /// + /// Benchmark@7B: 3.06GB, +0.2437 ppl + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, + + /// + /// K-Quant 3 bit (Large) + /// + /// Benchmark@7B: 3.35GB, +0.1803 ppl + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, + + /// + /// K-Quant 4 bit (Small) + /// + /// Benchmark@7B: 3.56GB, +0.1149 ppl + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, + + /// + /// K-Quant 4 bit (Medium) + /// + /// Benchmark@7B: 3.80GB, +0.0535 ppl + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, + + /// + /// K-Quant 5 bit (Small) + /// + /// Benchmark@7B: 4.33GB, +0.0353 ppl + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, + + /// + /// K-Quant 5 bit (Medium) + /// + /// Benchmark@7B: 4.45GB, +0.0142 ppl + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, + + /// + /// K-Quant 6 bit + /// + /// Benchmark@7B: 5.15GB, +0.0044 ppl + LLAMA_FTYPE_MOSTLY_Q6_K = 18, + + /// + /// File type was not specified + /// + LLAMA_FTYPE_GUESSED = 1024 } } diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs new file mode 100644 index 00000000..688f5ccb --- /dev/null +++ b/LLama/Native/LLamaGrammarElement.cs @@ -0,0 +1,124 @@ +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace LLama.Native +{ + /// + /// grammar element type + /// + public enum LLamaGrammarElementType + { + /// + /// end of rule definition + /// + END = 0, + + /// + /// start of alternate definition for rule + /// + ALT = 1, + + /// + /// non-terminal element: reference to rule + /// + RULE_REF = 2, + + /// + /// terminal element: character (code point) + /// + CHAR = 3, + + /// + /// inverse char(s) ([^a], [^a-b] [^abc]) + /// + CHAR_NOT = 4, + + /// + /// modifies a preceding CHAR or CHAR_ALT to + /// be an inclusive range ([a-z]) + /// + CHAR_RNG_UPPER = 5, + + /// + /// modifies a preceding CHAR or + /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + /// + CHAR_ALT = 6, + }; + + /// + /// An element of a grammar + /// + [StructLayout(LayoutKind.Sequential)] + [DebuggerDisplay("{Type} {Value}")] + public readonly struct LLamaGrammarElement + : IEquatable + { + /// + /// The type of this element + /// + public readonly LLamaGrammarElementType Type; + + /// + /// Unicode code point or rule ID + /// + public readonly uint Value; + + /// + /// Construct a new LLamaGrammarElement + /// + /// + /// + public LLamaGrammarElement(LLamaGrammarElementType type, uint value) + { + Type = type; + Value = value; + } + + /// + public bool Equals(LLamaGrammarElement other) + { + if (Type != other.Type) + return false; + + // No need to compare values for the END rule + if (Type == LLamaGrammarElementType.END) + return true; + + return Value == other.Value; + } + + /// + public override bool Equals(object? 
obj) + { + return obj is LLamaGrammarElement other && Equals(other); + } + + /// + public override int GetHashCode() + { + unchecked + { + var hash = 2999; + hash = hash * 7723 + (int)Type; + hash = hash * 7723 + (int)Value; + return hash; + } + } + + internal bool IsCharElement() + { + switch (Type) + { + case LLamaGrammarElementType.CHAR: + case LLamaGrammarElementType.CHAR_NOT: + case LLamaGrammarElementType.CHAR_ALT: + case LLamaGrammarElementType.CHAR_RNG_UPPER: + return true; + default: + return false; + } + } + } +} diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs index 17ec035a..128e30aa 100644 --- a/LLama/Native/LLamaModelQuantizeParams.cs +++ b/LLama/Native/LLamaModelQuantizeParams.cs @@ -1,29 +1,40 @@ using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; namespace LLama.Native { + /// + /// Quantizer parameters used in the native API + /// public struct LLamaModelQuantizeParams { /// /// number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() /// public int nthread; + /// /// quantize to this llama_ftype /// public LLamaFtype ftype; + /// /// allow quantizing non-f32/f16 tensors /// - [MarshalAs(UnmanagedType.I1)] - public bool allow_requantize; + public bool allow_requantize + { + get => Convert.ToBoolean(_allow_requantize); + set => _allow_requantize = Convert.ToSByte(value); + } + private sbyte _allow_requantize; + /// /// quantize output.weight /// - [MarshalAs(UnmanagedType.I1)] - public bool quantize_output_tensor; + public bool quantize_output_tensor + { + get => Convert.ToBoolean(_quantize_output_tensor); + set => _quantize_output_tensor = Convert.ToSByte(value); + } + private sbyte _quantize_output_tensor; } } diff --git a/LLama/Native/LLamaTokenData.cs b/LLama/Native/LLamaTokenData.cs index a5ffda59..0d3a56fc 100644 --- a/LLama/Native/LLamaTokenData.cs +++ b/LLama/Native/LLamaTokenData.cs @@ -1,7 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; +using System.Runtime.InteropServices; namespace LLama.Native { diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 6e2c4a46..7a2965ed 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -2,6 +2,8 @@ using System.Buffers; using System.Runtime.InteropServices; +using llama_token = System.Int32; + namespace LLama.Native { /// @@ -15,9 +17,9 @@ namespace LLama.Native public readonly Memory data; /// - /// Indicates if `data` is sorted + /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_. 
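
The `[MarshalAs(UnmanagedType.I1)] bool` fields in these native structs are being replaced by `sbyte`-backed properties; that keeps each struct blittable (so it can be passed to llama.cpp without a marshalling copy) while still exposing a `bool` to C# callers. A stripped-down illustration of the pattern, not taken from the PR:

```csharp
using System;
using System.Runtime.InteropServices;

[StructLayout(LayoutKind.Sequential)]
public struct ExampleNativeFlags
{
    // The public surface stays a bool...
    public bool enabled
    {
        readonly get => Convert.ToBoolean(_enabled);
        set => _enabled = Convert.ToSByte(value);
    }

    // ...but the storage is a single signed byte, matching the 1-byte flag in the
    // corresponding C struct and keeping the layout blittable.
    private sbyte _enabled;
}
```
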
/// - public readonly bool sorted; + public bool sorted; /// /// Create a new LLamaTokenDataArray @@ -29,6 +31,20 @@ namespace LLama.Native data = tokens; sorted = isSorted; } + + /// + /// Create a new LLamaTokenDataArray, copying the data from the given logits + /// + /// + /// + public static LLamaTokenDataArray Create(ReadOnlySpan logits) + { + var candidates = new LLamaTokenData[logits.Length]; + for (var token_id = 0; token_id < logits.Length; token_id++) + candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); + + return new LLamaTokenDataArray(candidates); + } } /// @@ -51,8 +67,12 @@ namespace LLama.Native /// /// Indicates if the items in the array are sorted /// - [MarshalAs(UnmanagedType.I1)] - public bool sorted; + public bool sorted + { + get => Convert.ToBoolean(_sorted); + set => _sorted = Convert.ToSByte(value); + } + private sbyte _sorted; /// /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs new file mode 100644 index 00000000..354ade3b --- /dev/null +++ b/LLama/Native/NativeApi.Grammar.cs @@ -0,0 +1,45 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native +{ + using llama_token = Int32; + + public unsafe partial class NativeApi + { + /// + /// Create a new grammar from the given set of grammar rules + /// + /// + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index); + + /// + /// Free all memory from the given SafeLLamaGrammarHandle + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_grammar_free(IntPtr grammar); + + /// + /// Apply constraints from grammar + /// + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_sample_grammar(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaGrammarHandle grammar); + + /// + /// Accepts the sampled token into the grammar + /// + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, llama_token token); + } +} diff --git a/LLama/Native/NativeApi.Quantize.cs b/LLama/Native/NativeApi.Quantize.cs index 8b201dde..d4ff5cf8 100644 --- a/LLama/Native/NativeApi.Quantize.cs +++ b/LLama/Native/NativeApi.Quantize.cs @@ -1,7 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; +using System.Runtime.InteropServices; namespace LLama.Native { @@ -16,6 +13,6 @@ namespace LLama.Native /// not great API - very likely to change /// Returns 0 on success [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public unsafe static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param); + public static extern unsafe int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param); } } diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs index 45a9caf0..631b8713 100644 --- a/LLama/Native/NativeApi.Sampling.cs +++ b/LLama/Native/NativeApi.Sampling.cs @@ -7,6 +7,16 @@ namespace LLama.Native public unsafe partial class NativeApi { + /// + /// Apply 
classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// + /// + /// A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. + /// A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_sample_classifier_free_guidance(SafeLLamaContextHandle ctx, LLamaTokenDataArrayNative candidates, SafeLLamaContextHandle guidanceCtx, float scale); + /// /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @@ -16,7 +26,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty); + public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty); /// /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. @@ -28,7 +38,17 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); + public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); + + /// + /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// + /// + /// A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. + /// A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_sample_classifier_free_guidance(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaContextHandle guidance_ctx, float scale); /// /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. @@ -98,7 +118,7 @@ namespace LLama.Native /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
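
The penalty bindings above now take `llama_token*` instead of `llama_token[]`, so a caller working directly against `NativeApi` must pin its managed array for the duration of the call. A minimal sketch using the exact `llama_sample_repetition_penalty` signature from this diff:

```csharp
using LLama.Native;

static class PenaltySketch
{
    // The raw binding takes llama_token* (int*), so the managed array is pinned
    // with `fixed` while the native call runs.
    public static unsafe void ApplyRepetitionPenalty(
        SafeLLamaContextHandle ctx,
        ref LLamaTokenDataArrayNative candidates,
        int[] lastTokens,
        float penalty)
    {
        fixed (int* tokens = lastTokens)
        {
            NativeApi.llama_sample_repetition_penalty(ctx, ref candidates, tokens, (ulong)lastTokens.Length, penalty);
        }
    }
}
```
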
/// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, float* mu); + public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu); /// /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -110,7 +130,7 @@ namespace LLama.Native /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, float* mu); + public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu); /// /// Selects the token with the highest probability. diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index edfb4152..e9666ea8 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -1,14 +1,28 @@ using System; +using System.Buffers; using System.Runtime.InteropServices; using System.Text; +using LLama.Common; using LLama.Exceptions; +#pragma warning disable IDE1006 // Naming Styles + namespace LLama.Native { using llama_token = Int32; - public unsafe partial class NativeApi + + /// + /// Callback from llama.cpp with log messages + /// + /// + /// + public delegate void LLamaLogCallback(ILLamaLogger.LogLevel level, string message); + + /// + /// Direct translation of the llama.cpp API + /// + public unsafe partial class NativeApi { - public static readonly int LLAMA_MAX_DEVICES = 1; static NativeApi() { try @@ -28,21 +42,50 @@ namespace LLama.Native } private const string libraryName = "libllama"; + /// + /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. 
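(Illustrative usage, not one of the diff hunks: the new LLamaLogCallback delegate and llama_empty_call are naturally paired at startup, assuming console logging is wanted.)

using System;
using LLama.Native;

// Route llama.cpp log output through a managed callback, then force the
// native library to load so any resolution problem surfaces immediately.
NativeApi.llama_log_set((level, message) => Console.Write($"[{level}] {message}"));
NativeApi.llama_empty_call();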
+ /// + /// [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_empty_call(); + /// + /// Create a LLamaContextParams with default values + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern LLamaContextParams llama_context_default_params(); + /// + /// Create a LLamaModelQuantizeParams with default values + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern LLamaModelQuantizeParams llama_model_quantize_default_params(); + /// + /// Check if memory mapping is supported + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_mmap_supported(); + /// + /// Check if memory lockingis supported + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_mlock_supported(); + /// + /// Export a static computation graph for context of 511 and batch size of 1 + /// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these + /// parameters here to keep things simple + /// IMPORTANT: do not use for anything else other than debugging and testing! + /// + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname); @@ -52,13 +95,20 @@ namespace LLama.Native /// Return NULL on failure /// /// - /// + /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_); + public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params); + /// + /// Create a new llama_context with the given model. + /// Return value should always be wrapped in SafeLLamaContextHandle! + /// + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_); + public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); /// /// not great API - very likely to change. @@ -69,7 +119,7 @@ namespace LLama.Native public static extern void llama_backend_init(bool numa); /// - /// Frees all allocated memory + /// Frees all allocated memory in the given llama_context /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] @@ -223,9 +273,6 @@ namespace LLama.Native /// /// Convert the provided text into tokens. - /// The tokens pointer must be large enough to hold the resulting tokens. - /// Returns the number of tokens on success, no more than n_max_tokens - /// Returns a negative number on failure - the number of tokens that would have been returned /// /// /// @@ -233,35 +280,72 @@ namespace LLama.Native /// /// /// - /// + /// Returns the number of tokens on success, no more than n_max_tokens. 
+ /// Returns a negative number on failure - the number of tokens that would have been returned + /// public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) { - var bytes = encoding.GetBytes(text); - sbyte[] data = new sbyte[bytes.Length]; - for(int i = 0; i < bytes.Length; i++) + // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) + var byteCount = encoding.GetByteCount(text); + var array = ArrayPool.Shared.Rent(byteCount + 1); + try + { + // Convert to bytes + fixed (char* textPtr = text) + fixed (byte* arrayPtr = array) + { + encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length); + } + + // Add a zero byte to the end to terminate the string + array[byteCount] = 0; + + // Do the actual tokenization + fixed (byte* arrayPtr = array) + fixed (llama_token* tokensPtr = tokens) + return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos); + } + finally { - data[i] = (sbyte)bytes[i]; - //if (bytes[i] < 128) - //{ - // data[i] = (sbyte)bytes[i]; - //} - //else - //{ - // data[i] = (sbyte)(~((sbyte)(~bytes[i] + 1)) + 1); - //} + ArrayPool.Shared.Return(array); } - return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos); } + /// + /// Convert the provided text into tokens. + /// + /// + /// + /// + /// + /// + /// Returns the number of tokens on success, no more than n_max_tokens. + /// Returns a negative number on failure - the number of tokens that would have been returned + /// [DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos); + public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos); + /// + /// Get the number of tokens in the model vocabulary for this context + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); + /// + /// Get the size of the context window for the model for this context + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); + /// + /// Get the dimension of embedding vectors from the model for this context + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_embd(SafeLLamaContextHandle ctx); @@ -295,18 +379,38 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token); + /// + /// Get the "Beginning of sentence" token + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_bos(); + public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx); + /// + /// Get the "End of sentence" token + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_eos(); + public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx); + /// + /// Get the "new line" token + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token 
llama_token_nl(); + public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx); + /// + /// Print out timing information for this context + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_print_timings(SafeLLamaContextHandle ctx); + /// + /// Reset all collected timing information for this context + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_reset_timings(SafeLLamaContextHandle ctx); @@ -317,19 +421,60 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern IntPtr llama_print_system_info(); + /// + /// Get the number of tokens in the model vocabulary + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model); + public static extern int llama_model_n_vocab(SafeLlamaModelHandle model); + /// + /// Get the size of the context window for the model + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model); + public static extern int llama_model_n_ctx(SafeLlamaModelHandle model); + /// + /// Get the dimension of embedding vectors from this model + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model); + public static extern int llama_model_n_embd(SafeLlamaModelHandle model); + /// + /// Convert a single token into text + /// + /// + /// + /// buffer to write string into + /// size of the buffer + /// The length writte, or if the buffer is too small a negative that indicates the length required [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken); + public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); + /// + /// Convert text into tokens + /// + /// + /// + /// + /// + /// + /// Returns the number of tokens on success, no more than n_max_tokens. 
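(Illustrative usage, not one of the diff hunks: llama_token_to_piece_with_model reports a negative required size when the buffer is too small, so a zero-length probe call followed by a sized call is the intended pattern. `model` and `token` are assumed inputs.)

using System.Text;
using LLama.Native;

static unsafe string PieceToString(SafeLlamaModelHandle model, int token)
{
    // Probe with an empty buffer; a negative result is the required byte count
    var needed = NativeApi.llama_token_to_piece_with_model(model, token, null, 0);
    if (needed == 0)
        return string.Empty;

    var bytes = new byte[-needed];
    fixed (byte* ptr = bytes)
        NativeApi.llama_token_to_piece_with_model(model, token, ptr, bytes.Length);

    return Encoding.UTF8.GetString(bytes);
}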
+ /// Returns a negative number on failure - the number of tokens that would have been returned + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos); - } + + /// + /// Register a callback to receive llama log messages + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_log_set(LLamaLogCallback logCallback); + } } diff --git a/LLama/Native/NativeInfo.cs b/LLama/Native/NativeInfo.cs deleted file mode 100644 index 6711db78..00000000 --- a/LLama/Native/NativeInfo.cs +++ /dev/null @@ -1,15 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace LLama.Native -{ - internal class NativeInfo - { - internal static readonly int LLAMA_FILE_VERSION = 1; - internal static readonly string LLAMA_FILE_MAGIC = "ggjt"; - internal static readonly string LLAMA_FILE_MAGIC_UNVERSIONED = "ggml"; - internal static readonly string LLAMA_SESSION_MAGIC = "ggsn"; - internal static readonly int LLAMA_SESSION_VERSION = 1; - } -} diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index fa54f73e..aa9c9439 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; @@ -8,7 +9,7 @@ namespace LLama.Native /// /// A safe wrapper around a llama_context /// - public class SafeLLamaContextHandle + public sealed class SafeLLamaContextHandle : SafeLLamaHandleBase { #region properties and fields @@ -25,11 +26,13 @@ namespace LLama.Native /// /// Dimension of embedding vectors /// - public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount; + public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize; /// - /// This field guarantees that a reference to the model is held for as long as this handle is held + /// Get the model which this context is using /// + public SafeLlamaModelHandle ModelHandle => ThrowIfDisposed(); + private SafeLlamaModelHandle? _model; #endregion @@ -55,7 +58,7 @@ namespace LLama.Native { // Decrement refcount on model _model?.DangerousRelease(); - _model = null; + _model = null!; NativeApi.llama_free(handle); SetHandle(IntPtr.Zero); @@ -69,7 +72,7 @@ namespace LLama.Native if (_model == null || _model.IsClosed) throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed"); - return _model; + return _model!; } /// @@ -87,6 +90,35 @@ namespace LLama.Native return new(ctx_ptr, model); } + + /// + /// Create a new llama context with a clone of the current llama context state + /// + /// + /// + public SafeLLamaContextHandle Clone(LLamaContextParams lparams) + { + // Allocate space to read the state of the current context + var stateSize = GetStateSize(); + var stateMemory = Marshal.AllocHGlobal((nint)stateSize); + try + { + // Copy state from this context into memory + GetState(stateMemory, stateSize); + + // Create a new context + var newCtx = Create(ModelHandle, lparams); + + // Copy state into new context + newCtx.SetState(stateMemory); + + return newCtx; + } + finally + { + Marshal.FreeHGlobal(stateMemory); + } + } #endregion /// @@ -136,7 +168,6 @@ namespace LLama.Native /// Rows: n_tokens
/// Cols: n_vocab ///
- /// /// public Span GetLogits() { @@ -152,7 +183,7 @@ namespace LLama.Native /// /// Convert a token into a string /// - /// + /// Token to decode into a string /// /// public string TokenToString(int token, Encoding encoding) @@ -161,13 +192,25 @@ namespace LLama.Native } /// - /// Convert a token into a span of bytes that could be decoded into a string + /// Append a single llama token to a string builder /// - /// - /// - public ReadOnlySpan TokenToSpan(int token) + /// Token to decode + /// + /// string builder to append the result to + public void TokenToString(int token, Encoding encoding, StringBuilder dest) { - return ThrowIfDisposed().TokenToSpan(token); + ThrowIfDisposed().TokenToString(token, encoding, dest); + } + + /// + /// Convert a single llama token into bytes + /// + /// Token to decode + /// A span to attempt to write into. If this is too small nothing will be written + /// The size of this token. **nothing will be written** if this is larger than `dest` + public int TokenToSpan(int token, Span dest) + { + return ThrowIfDisposed().TokenToSpan(token, dest); } /// @@ -177,13 +220,79 @@ namespace LLama.Native /// the number of tokens to use from previous eval calls /// /// Returns true on success - public bool Eval(Memory tokens, int n_past, int n_threads) + public bool Eval(ReadOnlySpan tokens, int n_past, int n_threads) { - using var pin = tokens.Pin(); unsafe { - return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0; + fixed (int* pinned = tokens) + { + return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0; + } } } + + #region state + /// + /// Get the size of the state, when saved as bytes + /// + public ulong GetStateSize() + { + return NativeApi.llama_get_state_size(this); + } + + /// + /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer. + /// + /// Destination to write to + /// Number of bytes available to write to in dest (check required size with `GetStateSize()`) + /// The number of bytes written to dest + /// Thrown if dest is too small + public unsafe ulong GetState(byte* dest, ulong size) + { + return GetState(new IntPtr(dest), size); + } + + /// + /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer. 
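(Illustrative usage, not one of the diff hunks: the state helpers compose the same way Clone() uses them internally, snapshotting the context into unmanaged memory and restoring it later. `ctx` is assumed to be a live SafeLLamaContextHandle.)

using System.Runtime.InteropServices;
using LLama.Native;

static void SnapshotAndRestore(SafeLLamaContextHandle ctx)
{
    var size = ctx.GetStateSize();
    var buffer = Marshal.AllocHGlobal((nint)size);
    try
    {
        ctx.GetState(buffer, size);   // copy the current state out
        // ... run evaluation that should later be rolled back ...
        ctx.SetState(buffer);         // restore the snapshot
    }
    finally
    {
        Marshal.FreeHGlobal(buffer);
    }
}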
+ /// + /// Destination to write to + /// Number of bytes available to write to in dest (check required size with `GetStateSize()`) + /// The number of bytes written to dest + /// Thrown if dest is too small + public ulong GetState(IntPtr dest, ulong size) + { + var required = GetStateSize(); + if (size < required) + throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}"); + + unsafe + { + return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer()); + } + } + + /// + /// Set the raw state of this context + /// + /// The pointer to read the state from + /// Number of bytes read from the src pointer + public unsafe ulong SetState(byte* src) + { + return SetState(new IntPtr(src)); + } + + /// + /// Set the raw state of this context + /// + /// The pointer to read the state from + /// Number of bytes read from the src pointer + public ulong SetState(IntPtr src) + { + unsafe + { + return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer()); + } + } + #endregion } } diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs new file mode 100644 index 00000000..ed1c15c8 --- /dev/null +++ b/LLama/Native/SafeLLamaGrammarHandle.cs @@ -0,0 +1,106 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using LLama.Exceptions; +using LLama.Grammars; + +namespace LLama.Native +{ + /// + /// A safe reference to a `llama_grammar` + /// + public class SafeLLamaGrammarHandle + : SafeLLamaHandleBase + { + #region construction/destruction + /// + /// + /// + /// + internal SafeLLamaGrammarHandle(IntPtr handle) + : base(handle) + { + } + + /// + protected override bool ReleaseHandle() + { + NativeApi.llama_grammar_free(handle); + SetHandle(IntPtr.Zero); + return true; + } + + /// + /// Create a new llama_grammar + /// + /// A list of list of elements, each inner list makes up one grammar rule + /// The index (in the outer list) of the start rule + /// + /// + public static SafeLLamaGrammarHandle Create(IReadOnlyList rules, ulong start_rule_index) + { + unsafe + { + var totalElements = rules.Sum(a => a.Elements.Count); + var nrules = (ulong)rules.Count; + + // Borrow an array large enough to hold every single element + // and another array large enough to hold a pointer to each rule + var allElements = ArrayPool.Shared.Rent(totalElements); + var pointers = ArrayPool.Shared.Rent(rules.Count); + try + { + fixed (LLamaGrammarElement* allElementsPtr = allElements) + { + var elementIndex = 0; + var pointerIndex = 0; + foreach (var rule in rules) + { + // Save a pointer to the start of this rule + pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex); + + // Copy all of the rule elements into the flat array + foreach (var element in rule.Elements) + allElementsPtr[elementIndex++] = element; + } + + // Sanity check some things that should be true if the copy worked as planned + Debug.Assert((ulong)pointerIndex == nrules); + Debug.Assert(elementIndex == totalElements); + + // Make the actual call through to llama.cpp + fixed (void* ptr = pointers) + { + return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index); + } + } + } + finally + { + ArrayPool.Shared.Return(allElements); + ArrayPool.Shared.Return(pointers); + } + } + } + + /// + /// Create a new llama_grammar + /// + /// rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element) + /// total number of rules + /// index of the start 
rule of the grammar + /// + /// + public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index) + { + var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index); + if (grammar_ptr == IntPtr.Zero) + throw new RuntimeError("Failed to create grammar from rules"); + + return new(grammar_ptr); + } + #endregion + } +} diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index dbb1b070..7074fddb 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Text; using LLama.Exceptions; @@ -7,7 +8,7 @@ namespace LLama.Native /// /// A reference to a set of llama model weights /// - public class SafeLlamaModelHandle + public sealed class SafeLlamaModelHandle : SafeLLamaHandleBase { /// @@ -23,14 +24,14 @@ namespace LLama.Native /// /// Dimension of embedding vectors /// - public int EmbeddingCount { get; } + public int EmbeddingSize { get; } internal SafeLlamaModelHandle(IntPtr handle) : base(handle) { - VocabCount = NativeApi.llama_n_vocab_from_model(this); - ContextSize = NativeApi.llama_n_ctx_from_model(this); - EmbeddingCount = NativeApi.llama_n_embd_from_model(this); + VocabCount = NativeApi.llama_model_n_vocab(this); + ContextSize = NativeApi.llama_model_n_ctx(this); + EmbeddingSize = NativeApi.llama_model_n_embd(this); } /// @@ -82,17 +83,20 @@ namespace LLama.Native #region tokenize /// - /// Convert a single llama token into string bytes + /// Convert a single llama token into bytes /// - /// - /// - public ReadOnlySpan TokenToSpan(int llama_token) + /// Token to decode + /// A span to attempt to write into. If this is too small nothing will be written + /// The size of this token. 
**nothing will be written** if this is larger than `dest` + public int TokenToSpan(int llama_token, Span dest) { unsafe { - var bytes = new ReadOnlySpan(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue); - var terminator = bytes.IndexOf((byte)0); - return bytes.Slice(0, terminator); + fixed (byte* destPtr = dest) + { + var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length); + return Math.Abs(length); + } } } @@ -104,16 +108,54 @@ namespace LLama.Native /// public string TokenToString(int llama_token, Encoding encoding) { - var span = TokenToSpan(llama_token); + unsafe + { + var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); + if (length == 0) + return ""; - if (span.Length == 0) - return ""; + Span bytes = stackalloc byte[-length]; + + fixed (byte* bytePtr = bytes) + { + var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); + Debug.Assert(written == bytes.Length); + + return encoding.GetString(bytePtr, bytes.Length); + } + } + } + /// + /// Append a single llama token to a string builder + /// + /// Token to decode + /// + /// string builder to append the result to + public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest) + { unsafe { - fixed (byte* ptr = &span[0]) + var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); + if (length == 0) + return; + + Span bytes = stackalloc byte[-length]; + fixed (byte* bytePtr = bytes) { - return encoding.GetString(ptr, span.Length); + // Decode into bytes + var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); + Debug.Assert(written == bytes.Length); + + // Decode into chars + var charCount = encoding.GetCharCount(bytePtr, bytes.Length); + Span chars = stackalloc char[charCount]; + fixed (char* charPtr = chars) + encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length); + + // Write it to the output + for (var i = 0; i < chars.Length; i++) + dest.Append(chars[i]); } } } @@ -150,12 +192,24 @@ namespace LLama.Native var tokens = new int[count]; fixed (int* tokensPtr = &tokens[0]) { - count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos); + NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos); return tokens; } } } } #endregion + + #region context + /// + /// Create a new context for this model + /// + /// + /// + public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) + { + return SafeLLamaContextHandle.Create(this, @params); + } + #endregion } } diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs index abe01b15..e26bf971 100644 --- a/LLama/Native/SamplingApi.cs +++ b/LLama/Native/SamplingApi.cs @@ -1,10 +1,28 @@ using System; +#pragma warning disable IDE1006 // Naming Styles + namespace LLama.Native { using llama_token = Int32; + + /// + /// Direct translation of the llama.cpp sampling API + /// public unsafe class SamplingApi { + /// + /// Apply grammar rules to candidate tokens + /// + /// + /// + /// + public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar) + { + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_grammar(ctx, ref st, grammar); + } + /// /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. 
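(Illustrative usage, not one of the diff hunks: the grammar pieces are meant to compose in a sampling loop by constraining the candidates, picking a token, then feeding the choice back to the grammar. `ctx` and `grammar` are assumed to exist already, and the greedy picker is assumed to be the existing SamplingApi wrapper.)

using LLama.Native;

static int SampleWithGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
{
    // Candidates built from the logits of the last eval call
    var candidates = LLamaTokenDataArray.Create(ctx.GetLogits());

    // Mask out candidates the grammar does not currently allow
    SamplingApi.llama_sample_grammar(ctx, candidates, grammar);

    // Any sampler can follow; greedy keeps the sketch short
    var token = SamplingApi.llama_sample_token_greedy(ctx, candidates);

    // Tell the grammar which token was accepted so it can advance its state
    NativeApi.llama_grammar_accept_token(ctx, grammar, token);
    return token;
}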
/// @@ -13,10 +31,25 @@ namespace LLama.Native /// /// /// - public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty) + [Obsolete("last_tokens_size parameter is no longer needed")] + public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float penalty) + { + llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty); + } + + /// + /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// + /// + /// Pointer to LLamaTokenDataArray + /// + /// + public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float penalty) { using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty); + using var last_tokens_handle = last_tokens.Pin(); + + NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty); } /// @@ -28,10 +61,26 @@ namespace LLama.Native /// /// /// - public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) + [Obsolete("last_tokens_size parameter is no longer needed")] + public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) + { + llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence); + } + + /// + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + /// + /// + /// Pointer to LLamaTokenDataArray + /// + /// + /// + public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float alpha_frequency, float alpha_presence) { using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence); + using var last_tokens_handle = last_tokens.Pin(); + + NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence); } /// @@ -97,6 +146,13 @@ namespace LLama.Native NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); } + /// + /// Sample with temperature. 
+ /// As temperature increases, the prediction becomes diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual + /// + /// + /// + /// public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) { using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); @@ -116,10 +172,7 @@ namespace LLama.Native public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) { using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - fixed(float* pmu = &mu) - { - return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu); - } + return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu); } /// @@ -134,10 +187,7 @@ namespace LLama.Native public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) { using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - fixed (float* pmu = &mu) - { - return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu); - } + return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu); } /// diff --git a/LLama/OldVersion/ChatSession.cs b/LLama/OldVersion/ChatSession.cs index 1bf954fa..f8409d30 100644 --- a/LLama/OldVersion/ChatSession.cs +++ b/LLama/OldVersion/ChatSession.cs @@ -1,10 +1,13 @@ using System; using System.Collections.Generic; using System.IO; -using System.Text; + +#pragma warning disable +// ReSharper disable all namespace LLama.OldVersion { + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public class ChatSession where T : IChatModel { IChatModel _model; diff --git a/LLama/OldVersion/IChatModel.cs b/LLama/OldVersion/IChatModel.cs index 7fbd898b..de32fc09 100644 --- a/LLama/OldVersion/IChatModel.cs +++ b/LLama/OldVersion/IChatModel.cs @@ -1,9 +1,12 @@ using System; using System.Collections.Generic; -using System.Text; + +#pragma warning disable +// ReSharper disable all namespace LLama.OldVersion { + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public interface IChatModel { string Name { get; } diff --git a/LLama/OldVersion/LLamaEmbedder.cs b/LLama/OldVersion/LLamaEmbedder.cs index 823c4437..7b6aedb6 100644 --- a/LLama/OldVersion/LLamaEmbedder.cs +++ b/LLama/OldVersion/LLamaEmbedder.cs @@ -1,12 +1,15 @@ using LLama.Native; using System; -using System.Collections.Generic; -using System.Text; using LLama.Exceptions; +#pragma warning disable +// ReSharper disable all + namespace LLama.OldVersion { - public class LLamaEmbedder : IDisposable + [Obsolete("The entire LLama.OldVersion namespace will be removed")] + public class LLamaEmbedder + : IDisposable { SafeLLamaContextHandle _ctx; diff --git a/LLama/OldVersion/LLamaModel.cs b/LLama/OldVersion/LLamaModel.cs index bf400ba4..523b9553 100644 --- a/LLama/OldVersion/LLamaModel.cs +++ b/LLama/OldVersion/LLamaModel.cs @@ -9,10 +9,16 @@ using System.Linq; using System.Text; using LLama.Common; +#pragma warning disable +// ReSharper disable all + namespace LLama.OldVersion { using llama_token = Int32; - public class LLamaModel : IChatModel, IDisposable + + [Obsolete("The entire LLama.OldVersion namespace will be removed")] + public class LLamaModel + : IChatModel, IDisposable { LLamaParams _params; SafeLLamaContextHandle _ctx; @@ -27,7 +33,6 @@ namespace LLama.OldVersion 
bool _is_interacting; bool _is_antiprompt; bool _input_echo; - bool _verbose; // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session // if we loaded a session with at least 75% similarity. It's currently just used to speed up the @@ -40,17 +45,8 @@ namespace LLama.OldVersion List _embed; public string Name { get; set; } - public bool Verbose - { - get - { - return _verbose; - } - set - { - _verbose = value; - } - } + public bool Verbose { get; set; } + public SafeLLamaContextHandle NativeHandle => _ctx; /// @@ -173,7 +169,7 @@ namespace LLama.OldVersion { Name = name; _params = @params; - _verbose = verbose; + Verbose = verbose; _ctx = Utils.llama_init_from_gpt_params(ref _params); // Add a space in front of the first character to match OG llama tokenizer behavior @@ -509,7 +505,7 @@ namespace LLama.OldVersion } if (_is_interacting) { - if (_verbose) + if (Verbose) { LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it."); } @@ -620,7 +616,7 @@ namespace LLama.OldVersion NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count); } - llama_token id = 0; + llama_token id; { var n_vocab = NativeApi.llama_n_vocab(_ctx); @@ -638,7 +634,7 @@ namespace LLama.OldVersion LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); // Apply penalties - float nl_logit = logits[NativeApi.llama_token_nl()]; + float nl_logit = logits[NativeApi.llama_token_nl(_ctx)]; var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx); SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, _last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(), @@ -648,7 +644,7 @@ namespace LLama.OldVersion (ulong)last_n_repeat, alpha_frequency, alpha_presence); if (!penalize_nl) { - logits[NativeApi.llama_token_nl()] = nl_logit; + logits[NativeApi.llama_token_nl(_ctx)] = nl_logit; } if (temp <= 0) @@ -688,7 +684,7 @@ namespace LLama.OldVersion } // replace end of text token with newline token when in interactive mode - if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct) + if (id == NativeApi.llama_token_eos(_ctx) && _params.interactive && !_params.instruct) { id = _llama_token_newline[0]; if (_params.antiprompt.Count != 0) @@ -764,7 +760,7 @@ namespace LLama.OldVersion break; } - if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos()) + if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos(_ctx)) { if (_params.instruct) { diff --git a/LLama/OldVersion/LLamaParams.cs b/LLama/OldVersion/LLamaParams.cs index a2d677d8..2fa512ad 100644 --- a/LLama/OldVersion/LLamaParams.cs +++ b/LLama/OldVersion/LLamaParams.cs @@ -1,9 +1,14 @@ using System; using System.Collections.Generic; +#pragma warning disable +// ReSharper disable all + namespace LLama.OldVersion { using llama_token = Int32; + + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public struct LLamaParams { public int seed; // RNG seed @@ -58,7 +63,7 @@ namespace LLama.OldVersion public LLamaParams(int seed = 0, int n_threads = -1, int n_predict = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0, int n_gpu_layers = -1, - Dictionary logit_bias = null, int top_k = 40, float top_p = 0.95f, + Dictionary? 
logit_bias = null, int top_k = 40, float top_p = 0.95f, float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f, int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f, int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f, diff --git a/LLama/OldVersion/LLamaTypes.cs b/LLama/OldVersion/LLamaTypes.cs index d0bd4ad7..0cc4ed59 100644 --- a/LLama/OldVersion/LLamaTypes.cs +++ b/LLama/OldVersion/LLamaTypes.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; -using System.Text; + +#pragma warning disable +// ReSharper disable all namespace LLama.OldVersion { @@ -9,33 +11,49 @@ namespace LLama.OldVersion Human, Assistant } + + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record EmbeddingUsage(int PromptTokens, int TotalTokens); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record EmbeddingData(int Index, string Object, float[] Embedding); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record Embedding(string Object, string Model, EmbeddingData[] Data, EmbeddingUsage Usage); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record CompletionLogprobs(int[] TextOffset, float[] TokenLogProbs, string[] Tokens, Dictionary[] TopLogprobs); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record CompletionChoice(string Text, int Index, CompletionLogprobs? Logprobs, string? FinishReason); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record CompletionUsage(int PromptTokens, int CompletionTokens, int TotalTokens); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record CompletionChunk(string Id, string Object, int Created, string Model, CompletionChoice[] Choices); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record Completion(string Id, string Object, int Created, string Model, CompletionChoice[] Choices, CompletionUsage Usage); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record ChatCompletionMessage(ChatRole Role, string Content, string? Name = null); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record ChatCompletionChoice(int Index, ChatCompletionMessage Message, string? FinishReason); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record ChatCompletion(string Id, string Object, int Created, string Model, ChatCompletionChoice[] Choices, CompletionUsage Usage); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record ChatCompletionChunkDelta(string? Role, string? Content); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record ChatCompletionChunkChoice(int Index, ChatCompletionChunkDelta Delta, string? 
FinishReason); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record ChatCompletionChunk(string Id, string Model, string Object, int Created, ChatCompletionChunkChoice[] Choices); + [Obsolete("The entire LLama.OldVersion namespace will be removed")] public record ChatMessageRecord(ChatCompletionMessage Message, DateTime Time); } diff --git a/LLama/OldVersion/Utils.cs b/LLama/OldVersion/Utils.cs index df8adddd..5aa7876f 100644 --- a/LLama/OldVersion/Utils.cs +++ b/LLama/OldVersion/Utils.cs @@ -3,14 +3,17 @@ using System; using System.Collections.Generic; using System.Text; using LLama.Exceptions; -using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; using System.IO; +#pragma warning disable +// ReSharper disable all + namespace LLama.OldVersion { using llama_token = Int32; + internal static class Utils { public static SafeLLamaContextHandle llama_init_from_gpt_params(ref LLamaParams @params) @@ -54,7 +57,7 @@ namespace LLama.OldVersion return res.Take(n).ToList(); } - public unsafe static Span llama_get_logits(SafeLLamaContextHandle ctx, int length) + public static unsafe Span llama_get_logits(SafeLLamaContextHandle ctx, int length) { var logits = NativeApi.llama_get_logits(ctx); return new Span(logits, length); @@ -65,22 +68,26 @@ namespace LLama.OldVersion #if NET6_0_OR_GREATER return Marshal.PtrToStringUTF8(ptr); #else - byte* tp = (byte*)ptr.ToPointer(); - List bytes = new(); - while (true) + unsafe { - byte c = *tp++; - if (c == '\0') + byte* tp = (byte*)ptr.ToPointer(); + List bytes = new(); + while (true) { - break; - } - else - { - bytes.Add(c); + byte c = *tp++; + if (c == '\0') + { + break; + } + else + { + bytes.Add(c); + } } + return Encoding.UTF8.GetString(bytes.ToArray()); } - return Encoding.UTF8.GetString(bytes.ToArray()); #endif } + } } diff --git a/LLama/ResettableLLamaModel.cs b/LLama/ResettableLLamaModel.cs deleted file mode 100644 index d9b4e822..00000000 --- a/LLama/ResettableLLamaModel.cs +++ /dev/null @@ -1,43 +0,0 @@ -using LLama.Abstractions; -using System; -using System.Collections.Generic; -using System.Text; - -namespace LLama -{ - /// - /// A LLamaModel what could be reset. Note that using this class will consume about 10% more memories. - /// - public class ResettableLLamaModel : LLamaModel - { - /// - /// The initial state of the model - /// - public State OriginalState { get; set; } - /// - /// - /// - /// - /// - public ResettableLLamaModel(IModelParams Params, string encoding = "UTF-8") : base(Params, encoding) - { - OriginalState = GetState(); - } - - /// - /// Reset the state to the initial state. 
- /// - public void Reset() - { - LoadState(OriginalState); - } - - /// - public override void Dispose() - { - OriginalState.Dispose(); - - base.Dispose(); - } - } -} diff --git a/LLama/Utils.cs b/LLama/Utils.cs index de363a3e..f3584c81 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -9,30 +9,35 @@ using LLama.Extensions; namespace LLama { using llama_token = Int32; + + /// + /// Assorted llama utilities + /// public static class Utils { + [Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")] + #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params) + #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member { - using (@params.ToLlamaContextParams(out var lparams)) - { - var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); - var ctx = SafeLLamaContextHandle.Create(model, lparams); + using var weights = LLamaWeights.LoadFromFile(@params); - if (!string.IsNullOrEmpty(@params.LoraAdapter)) - model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); - - return ctx; - } + using (@params.ToLlamaContextParams(out var lparams)) + return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams); } [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")] + #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member public static IEnumerable Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding) + #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member { return ctx.Tokenize(text, add_bos, encoding); } [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")] + #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member public static Span GetLogits(SafeLLamaContextHandle ctx, int length) + #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member { if (length != ctx.VocabCount) throw new ArgumentException("length must be the VocabSize"); @@ -41,33 +46,41 @@ namespace LLama } [Obsolete("Use SafeLLamaContextHandle Eval method instead")] + #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) + #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member { - var slice = tokens.AsMemory().Slice(startIndex, n_tokens); + var slice = tokens.AsSpan().Slice(startIndex, n_tokens); return ctx.Eval(slice, n_past, n_threads) ? 
0 : 1; } [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")] + #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding) + #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member { return ctx.TokenToString(token, encoding); } [Obsolete("No longer used internally by LlamaSharp")] + #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member public static string PtrToString(IntPtr ptr, Encoding encoding) + #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member { #if NET6_0_OR_GREATER + // ReSharper disable once PossibleUnintendedReferenceComparison if(encoding == Encoding.UTF8) { - return Marshal.PtrToStringUTF8(ptr); + return Marshal.PtrToStringUTF8(ptr)!; } + // ReSharper disable once PossibleUnintendedReferenceComparison else if(encoding == Encoding.Unicode) { - return Marshal.PtrToStringUni(ptr); + return Marshal.PtrToStringUni(ptr)!; } else { - return Marshal.PtrToStringAuto(ptr); + return Marshal.PtrToStringAuto(ptr)!; } #else unsafe @@ -90,5 +103,6 @@ namespace LLama } #endif } + } } diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal index 8d26b5ec..82e1a0c7 100644 --- a/LLama/runtimes/ggml-metal.metal +++ b/LLama/runtimes/ggml-metal.metal @@ -18,46 +18,11 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) { - const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const half d = x[i].d; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) { - const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const half d = x[i].d; - const half m = x[i].m; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} +#define QK8_0 32 +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; kernel void kernel_add( device const float * src0, @@ -128,7 +93,12 @@ kernel void kernel_gelu( device float * dst, uint tpig[[thread_position_in_grid]]) { float x = src0[tpig]; - dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! 
+ // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } kernel void kernel_soft_max( @@ -219,54 +189,6 @@ kernel void kernel_diag_mask_inf( } } -kernel void kernel_get_rows_f16( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - for (int j = 0; j < ne00; j++) { - dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j]; - } -} - -kernel void kernel_get_rows_q4_0( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q4_0( - (device const block_q4_0 *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q4_1( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q4_1( - (device const block_q4_1 *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - kernel void kernel_norm( device const void * src0, device float * dst, @@ -432,14 +354,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre // N_DST, so this is another explicit assumption of the implementation. 
template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, - uint2 tgpig, uint tiisg, uint sgitg) { + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, + uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * nsg + sgitg) * nr; - device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float yl[16]; // src1 vector cache float sumf[nr]={0.f}; @@ -470,7 +394,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + first_row + row] = tot; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } } @@ -480,13 +404,17 @@ kernel void kernel_mul_mat_q4_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -494,13 +422,79 @@ kernel void kernel_mul_mat_q4_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mat_q8_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int 
nb = ne00/QK8_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float sumf[nr]={0.f}; + + const int ix = tiisg/2; + const int il = tiisg%2; + + device const float * yb = y + ix * QK8_0 + 16*il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + for (int i = 0; i < 16; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const int8_t * qs = x[ib+row*nb].qs + 16*il; + float sumq = 0.f; + for (int iq = 0; iq < 16; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*x[ib+row*nb].d; + } + + yb += QK8_0 * 16; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } } kernel void kernel_mul_mat_f16_f32( @@ -554,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32( } } - kernel void kernel_alibi_f32( device const float * src0, device float * dst, @@ -650,7 +643,25 @@ kernel void kernel_rope( dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - // TODO: implement + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } } } @@ -869,354 +880,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { return r; } -//========================================== dequantization ============================= - -static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = x[i].d; - const float min = x[i].dmin; - - device const uint8_t * q = x[i].qs; - -#if QK_K == 256 - int is = 0; - float dl, ml; - for (int n = 0; n < QK_K; n += 128) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - - uint8_t sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; - - sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; - - shift += 2; - } - q += 32; - } -#else - float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); - float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); - float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); - float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); - for (int l = 0; l < 16; ++l) { - y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1; - y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2; - y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3; - y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4; - } - y += QK_K; -#endif - - } -} - 
-static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - -#if QK_K == 256 - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - uint16_t aux[8]; - thread const int8_t * scales = (thread const int8_t*)aux; - - for (int i = 0; i < nb; i++) { - - const float d_all = (float)(x[i].d); - - device const uint8_t * q = x[i].qs; - device const uint8_t * h = x[i].hmask; - uint8_t m = 1; - - device const uint16_t * a = (device const uint16_t *)x[i].scales; - aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4); - aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4); - aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4); - aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4); - aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4); - aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4); - aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4); - aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4); - - int is = 0; - float dl; - for (int n = 0; n < QK_K; n += 128) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < 16; ++l) { - *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); - } - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < 16; ++l) { - *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); - } - - shift += 2; - m <<= 1; - } - q += 32; - } - } -#else - for (int i = 0; i < nb; i++) { - - const float d_all = (float)(x[i].d); - - device const uint8_t * q = x[i].qs; - device const uint8_t * hm = x[i].hmask; - - const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); - const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); - const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); - const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); - - for (int l = 0; l < 8; ++l) { - uint8_t h = hm[l]; - y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); - y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); - y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); - y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); - y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); - y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); - y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); - y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 
0 : 4)); - } - y += QK_K; - } -#endif - -} - -static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - device const uint8_t * q = x[i].qs; - -#if QK_K == 256 - const float d = x[i].d; - const float min = x[i].dmin; - - device const uint8_t * scales = x[i].scales; - - int is = 0; - for (int j = 0; j < QK_K; j += 64) { - const uchar4 sc = get_scale_min_k4(is, scales); - const float d1 = d * sc[0]; const float m1 = min * sc[1]; - const float d2 = d * sc[2]; const float m2 = min * sc[3]; - for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; - q += 32; is += 2; - } -#else - device const uint8_t * s = x[i].scales; - device const half2 * dh = (device const half2 *)x[i].d; - const float2 d = (float2)dh[0]; - const float d1 = d[0] * (s[0] & 0xF); - const float d2 = d[0] * (s[1] & 0xF); - const float m1 = d[1] * (s[0] >> 4); - const float m2 = d[1] * (s[1] >> 4); - for (int l = 0; l < 32; ++l) { - y[l+ 0] = d1 * (q[l] & 0xF) - m1; - y[l+32] = d2 * (q[l] >> 4) - m2; - } - y += QK_K; -#endif - - } -} - -static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - -#if QK_K == 256 - for (int i = 0; i < nb; i++) { - - const float d = (float)(x[i].d); - const float min = (float)(x[i].dmin); - - device const uint8_t * ql = x[i].qs; - device const uint8_t * qh = x[i].qh; - - int is = 0; - uint8_t u1 = 1, u2 = 2; - for (int j = 0; j < QK_K; j += 64) { - const uchar4 sc = get_scale_min_k4(is, x[i].scales); - const float d1 = d * sc[0]; const float m1 = min * sc[1]; - const float d2 = d * sc[2]; const float m2 = min * sc[3]; - for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; - ql += 32; is += 2; - u1 <<= 2; u2 <<= 2; - } - } -#else - for (int i = 0; i < nb; i++) { - - const float d = (float)x[i].d; - - device const uint8_t * ql = x[i].qs; - device const uint8_t * qh = x[i].qh; - device const int8_t * sc = x[i].scales; - - for (int l = 0; l < 8; ++l) { - y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); - y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); - y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); - y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); - y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); - y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); - y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); - y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 
0 : 16)); - } - y += QK_K; - } -#endif - -} - -static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - device const uint8_t * ql = x[i].ql; - device const uint8_t * qh = x[i].qh; - device const int8_t * sc = x[i].scales; - - const float d = x[i].d; - -#if QK_K == 256 - for (int n = 0; n < QK_K; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l + 0] = d * sc[is + 0] * q1; - y[l + 32] = d * sc[is + 2] * q2; - y[l + 64] = d * sc[is + 4] * q3; - y[l + 96] = d * sc[is + 6] * q4; - } - y += 128; - ql += 64; - qh += 32; - sc += 8; - } -#else - for (int l = 0; l < 16; ++l) { - const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l+ 0] = d * sc[0] * q1; - y[l+16] = d * sc[1] * q2; - y[l+32] = d * sc[2] * q3; - y[l+48] = d * sc[3] * q4; - } - y += 64; -#endif - } -} - -kernel void kernel_get_rows_q2_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q2_K( - (device const block_q2_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q3_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q3_K( - (device const block_q3_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q4_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q4_K( - (device const block_q4_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q5_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q5_K( - (device const block_q5_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q6_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint 
tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q6_K( - (device const block_q6_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - //====================================== dot products ========================= kernel void kernel_mul_mat_q2_K_f32( @@ -1224,21 +887,27 @@ kernel void kernel_mul_mat_q2_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -1351,7 +1020,7 @@ kernel void kernel_mul_mat_q2_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = all_sum; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1362,10 +1031,14 @@ kernel void kernel_mul_mat_q3_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - constant int64_t & ne1, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1373,11 +1046,12 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int64_t r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[16]; @@ -1465,7 +1139,7 @@ kernel void kernel_mul_mat_q3_K_f32( const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; } } } @@ -1475,10 +1149,14 @@ kernel void kernel_mul_mat_q3_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant 
int64_t & ne0, - constant int64_t & ne1, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1486,11 +1164,12 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int64_t r2 = tgpig.z; const int row = 2 * r0 + sgitg; - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; const int ix = tiisg/4; const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int im = il/8; // 0, 0, 1, 1 @@ -1529,7 +1208,7 @@ kernel void kernel_mul_mat_q3_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; } } @@ -1541,10 +1220,14 @@ kernel void kernel_mul_mat_q4_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1560,10 +1243,12 @@ kernel void kernel_mul_mat_q4_K_f32( const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[16]; float yh[16]; float sumf[N_DST]={0.f}, all_sum; @@ -1630,7 +1315,7 @@ kernel void kernel_mul_mat_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = all_sum; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1640,10 +1325,14 @@ kernel void kernel_mul_mat_q4_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1653,10 +1342,12 @@ kernel 
void kernel_mul_mat_q4_K_f32( const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[8]; float yh[8]; float sumf[N_DST]={0.f}, all_sum; @@ -1712,7 +1403,7 @@ kernel void kernel_mul_mat_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = all_sum; + dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1723,9 +1414,14 @@ kernel void kernel_mul_mat_q5_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1733,11 +1429,12 @@ kernel void kernel_mul_mat_q5_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float sumf[2]={0.f}; @@ -1871,7 +1568,7 @@ kernel void kernel_mul_mat_q5_K_f32( for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; } } @@ -1882,9 +1579,14 @@ kernel void kernel_mul_mat_q6_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1897,11 +1599,12 @@ kernel void kernel_mul_mat_q6_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int r2 = tgpig.z; const int row = 2 * r0 + sgitg; - - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float sumf = 0; 
@@ -1967,6 +1670,380 @@ kernel void kernel_mul_mat_q6_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; + } +} + +//============================= templates and their specializations ============================= + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const half d = il ? (xb->d / 16.h) : xb->d; + const half m = il ? ( -8.h * 16.h) : -8.h; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = il ? 0xF000 : 0x0F00; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; + reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; + } +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const half d = il ? (xb->d / 16.h) : xb->d; + const half m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = il ? 0xF000 : 0x0F00; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m; + reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i=0;i<16;i++) { + reg[i/4][i%4] = (qs[i + 16*il] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const half d = xb->d; + const half min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + half dl, ml; + uint8_t sc = xb->scales[il]; + +#if QK_K == 256 + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; +#endif + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; } } + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const float d_all = (float)(xb->d); + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \ + (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + + il = (il/2)%4; + float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef)); + } +#else + float kcoef = il&1 ? 1.f/16.f : 1.f; + uint16_t kmask = il&1 ? 
0xF0 : 0x0F; + float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); + float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + uint8_t m = 1<<(il*2); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); + } +#endif +} + +template +void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + +#if QK_K == 256 + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il%4; + const uchar4 sc = get_scale_min_k4(is, xb->scales); + const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; + const float ml = il<2 ? min * sc[1] : min * sc[3]; +#else + q = q + 16 * (il&1); + device const uint8_t * s = xb->scales; + device const half2 * dh = (device const half2 *)xb->d; + const float2 d = (float2)dh[0]; + const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; + const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4); +#endif + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + +#if QK_K == 256 + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il%4; + const uchar4 sc = get_scale_min_k4(is, xb->scales); + const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; + const float ml = il<2 ? min * sc[1] : min * sc[3]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +#else + q = q + 16 * (il&1); + device const int8_t * s = xb->scales; + const float dl = xb->d * s[il]; + uint8_t m = 1<<(il*2); + const float coef = il<2 ? 1.f : 1.f/16.f; + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); + } +#endif +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const float d_all = (float)(xb->d); + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2)%4; +#else + ql = ql + 16 * (il&1); + float sc = scales[il]; +#endif + for (int i = 0; i < 16; ++i) { + uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const float coef = il>1 ? 1.f/16.f : 1.f; + float q = il&1 ? 
((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \ + ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef; + reg[i/4][i%4] = d_all * sc * q * coef; + } +} + +template +kernel void kernel_get_rows( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tptg[[threads_per_threadgroup]]) { + const int i = tgpig; + const int r = ((device int32_t *) src1)[i]; + + for (int ind = tiitg; ind < ne00/16; ind += tptg) { + float4x4 temp; + dequantize_func( + ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; + } +} + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +kernel void kernel_mul_mm(device const uchar * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = ((threadgroup half *)shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ + + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + //load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + #pragma unroll(16) + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \ + = *((device float2x4 *)y); + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + //load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(4) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg==0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { + *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +#if QK_K == 256 +#define QK_NL 16 +#else +#define QK_NL 4 +#endif + +typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ + constant uint64_t &, constant uint64_t &, uint, uint, uint); + +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; +template 
[[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; + +typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ + constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ + constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); + +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll index 81af173b..6ed31810 100644 Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so index 75b884dd..81733cdd 100644 Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll index e6ff0a30..f1a9fbdc 100644 Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so index 6d20557b..482fe2f2 100644 Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ diff --git a/LLama/runtimes/libllama-metal.dylib b/LLama/runtimes/libllama-metal.dylib old mode 100644 new mode 100755 index 7cd1f4ab..e9c2ee28 Binary files a/LLama/runtimes/libllama-metal.dylib and b/LLama/runtimes/libllama-metal.dylib differ diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll index 8432f664..a5f774f8 100644 Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib old mode 100644 new mode 100755 index e4d0f1c7..53318c38 Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so index 1d7226a6..e52d6bda 100644 Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ diff --git a/LLamaSharp.sln b/LLamaSharp.sln index 2e00196c..2a039d41 100644 --- a/LLamaSharp.sln +++ b/LLamaSharp.sln @@ -11,7 +11,9 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LLamaSharp", "LLama\LLamaSh EndProject 
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LLama.WebAPI", "LLama.WebAPI\LLama.WebAPI.csproj", "{D3CEC57A-9027-4DA4-AAAC-612A1EB50ADF}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LLama.Web", "LLama.Web\LLama.Web.csproj", "{C3531DB2-1B2B-433C-8DE6-3541E3620DB1}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LLama.Web", "LLama.Web\LLama.Web.csproj", "{C3531DB2-1B2B-433C-8DE6-3541E3620DB1}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LLamaSharp.SemanticKernel", "LLama.SemanticKernel\LLamaSharp.SemanticKernel.csproj", "{D98F93E3-B344-4F9D-86BB-FDBF6768B587}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -83,6 +85,18 @@ Global {C3531DB2-1B2B-433C-8DE6-3541E3620DB1}.Release|Any CPU.Build.0 = Release|Any CPU {C3531DB2-1B2B-433C-8DE6-3541E3620DB1}.Release|x64.ActiveCfg = Release|Any CPU {C3531DB2-1B2B-433C-8DE6-3541E3620DB1}.Release|x64.Build.0 = Release|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Debug|x64.ActiveCfg = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Debug|x64.Build.0 = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.GPU|Any CPU.Build.0 = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.GPU|x64.ActiveCfg = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.GPU|x64.Build.0 = Debug|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Release|Any CPU.Build.0 = Release|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Release|x64.ActiveCfg = Release|Any CPU + {D98F93E3-B344-4F9D-86BB-FDBF6768B587}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/README.md b/README.md index 2bb6a17f..b4abc5fe 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ LLamaSharp provides two ways to run inference: `LLamaExecutor` and `ChatSession` using LLama.Common; using LLama; -string modelPath = "" // change it to your own model path +string modelPath = ""; // change it to your own model path var prompt = "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\r\n\r\nUser: Hello, Bob.\r\nBob: Hello. How may I help you today?\r\nUser: Please tell me the largest city in Europe.\r\nBob: Sure. The largest city in Europe is Moscow, the capital of Russia.\r\nUser:"; // use the "chat-with-bob" prompt here. // Initialize a chat session diff --git a/docs/ContributingGuide.md b/docs/ContributingGuide.md index c7f28b7c..1f3b3d47 100644 --- a/docs/ContributingGuide.md +++ b/docs/ContributingGuide.md @@ -33,11 +33,11 @@ When adding the feature, please take care of the namespace and the naming conven ## Find the problem and fix the BUG -If the issue is related to the LLM internal behaviors, such as endless generating the response, the best way to find the problem is to do comparison test between llama.cpp and LLamaSharp. +If the issue is related to the LLM internal behaviour, such as endless generating the response, the best way to find the problem is to do comparison test between llama.cpp and LLamaSharp. 
You could use exactly the same prompt, the same model and the same parameters to run the inference in llama.cpp and LLamaSharp respectively to see if it's really a problem caused by the implementation in LLamaSharp.
-If the experiment showed that it worked well in llama.cpp but didn't in LLamaSharp, a the search for the problem could be started. While the reason of the problem could be various, the best way I think is to add log-print in the code of llama.cpp and use it in LLamaSharp after compilation. Thus, when running LLamaSharp, you could see what happened in the native library.
+If the experiment showed that it worked well in llama.cpp but didn't in LLamaSharp, a search for the problem can be started. While the cause of the problem could vary, the best way in my opinion is to add log prints to the code of llama.cpp and use the recompiled library in LLamaSharp. Thus, when running LLamaSharp, you can see what happens inside the native library.
After finding out the reason, a painful but happy process comes. When working on the BUG fix, there's only one rule to follow, that is keeping the examples working well. If the modification fixed the BUG but impact on other functions, it would not be a good fix.
diff --git a/docs/GetStarted.md b/docs/GetStarted.md
index db46f2a3..b9248b57 100644
--- a/docs/GetStarted.md
+++ b/docs/GetStarted.md
@@ -54,8 +54,16 @@ using LLama;
string modelPath = "" // change it to your own model path
var prompt = "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\r\n\r\nUser: Hello, Bob.\r\nBob: Hello. How may I help you today?\r\nUser: Please tell me the largest city in Europe.\r\nBob: Sure. The largest city in Europe is Moscow, the capital of Russia.\r\nUser:"; // use the "chat-with-bob" prompt here.
+// Load model
+var parameters = new ModelParams(modelPath)
+{
+    ContextSize = 1024
+};
+using var model = LLamaWeights.LoadFromFile(parameters);
+
// Initialize a chat session
-var ex = new InteractiveExecutor(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
+using var context = model.CreateContext(parameters);
+var ex = new InteractiveExecutor(context);
ChatSession session = new ChatSession(ex);
// show the prompt
diff --git a/docs/Tricks.md b/docs/Tricks.md
index a75d6c21..4b72f440 100644
--- a/docs/Tricks.md
+++ b/docs/Tricks.md
@@ -1,11 +1,11 @@
# Tricks for FAQ
-Sometimes, your application with LLM and LLamaSharp may have strange behaviors. Before opening an issue to report the BUG, the following tricks may worth a try.
+Sometimes, your application with LLM and LLamaSharp may have strange behaviours. Before opening an issue to report the BUG, the following tricks may be worth a try.
## Carefully set the anti-prompts
-Anti-prompt can also be called as "Stop-keyword", which decides when to stop the response generation. Under interactive mode, the maximum tokens count is always not set, which makes the LLM generates responses infinitively. Therefore, setting anti-prompt correctly helps a lot to avoid the strange behaviors. For example, the prompt file `chat-with-bob.txt` has the following content:
+Anti-prompt can also be called a "Stop-keyword", which decides when to stop the response generation. Under interactive mode, the maximum token count is often left unset, which makes the LLM generate responses indefinitely.
Therefore, setting the anti-prompt correctly helps a lot to avoid strange behaviours. For example, the prompt file `chat-with-bob.txt` has the following content:
```
Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
@@ -19,7 +19,7 @@ User:
Therefore, the anti-prompt should be set as "User:". If the last line of the prompt is removed, LLM will automatically generate a question (user) and a response (bob) for one time when running the chat session. Therefore, the antiprompt is suggested to be appended to the prompt when starting a chat session.
-What if an extra line is appended? The string "User:" in the prompt will be followed with a char "\n". Thus when running the model, the automatic generation of a pair of question and response may appear because the anti-prompt is "User:" but the last token is "User:\n". As for whether it will appear, it's an undefined behavior, which depends on the implementation inside the `LLamaExecutor`. Anyway, since it may leads to unexpected behaviors, it's recommended to trim your prompt or carefully keep consistent with your anti-prompt.
+What if an extra line is appended? The string "User:" in the prompt will be followed by a "\n" character. Thus, when running the model, the automatic generation of a question-and-response pair may appear because the anti-prompt is "User:" but the last token is "User:\n". Whether it appears is undefined behaviour, depending on the implementation inside the `LLamaExecutor`. Anyway, since it may lead to unexpected behaviours, it's recommended to trim your prompt or keep it carefully consistent with your anti-prompt.
## Pay attention to the length of prompt
@@ -37,7 +37,7 @@ If your chat bot has bad performance, trying different executor will possibly ma
## Choose models weight depending on you task
-The differences between modes may lead to much different behaviors under the same task. For example, if you're building a chat bot with non-English, a fine-tuned model specially for the language you want to use will have huge effect on the performance.
+The differences between models may lead to very different behaviours under the same task. For example, if you're building a chat bot for a non-English language, a model fine-tuned specifically for that language will have a huge effect on the performance.
## Set the layer count you want to offload to GPU
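Both of the tricks above are configured in code. The following C# sketch ties them together, following the ModelParams/LLamaWeights/ChatSession API shown in the GetStarted diff earlier in this patch; InferenceParams, AntiPrompts, and the concrete values (context size, layer count, stop word, prompt) are illustrative assumptions rather than recommendations:

```csharp
using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

// Sketch only: combine anti-prompts with GPU layer offloading.
var parameters = new ModelParams("<your model path>")
{
    ContextSize = 1024,
    GpuLayerCount = 20                 // layers to offload to the GPU; tune for your VRAM
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var session = new ChatSession(new InteractiveExecutor(context));

var inferenceParams = new InferenceParams
{
    AntiPrompts = new List<string> { "User:" }   // stop generating once the model emits "User:"
};

foreach (var text in session.Chat("User: Hello, Bob.\nBob:", inferenceParams))
{
    Console.Write(text);
}
```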