| @@ -0,0 +1,240 @@ | |||||
| name: Update Binaries | |||||
| on: | |||||
| workflow_dispatch: | |||||
| inputs: | |||||
| cublas: | |||||
| type: boolean | |||||
| description: Build CUBLAS binaries | |||||
| macos: | |||||
| type: boolean | |||||
| description: Build MacOS binaries | |||||
| push: | |||||
| branches: [cron_job] | |||||
| #schedule: | |||||
| # - cron: "22 22 * * 2" | |||||
| jobs: | |||||
| compile-linux: | |||||
| name: Compile (Linux) | |||||
| strategy: | |||||
| fail-fast: true | |||||
| matrix: | |||||
| include: | |||||
| - build: 'noavx' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON' | |||||
| - build: 'avx2' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON' | |||||
| - build: 'avx' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DBUILD_SHARED_LIBS=ON' | |||||
| - build: 'avx512' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON' | |||||
| runs-on: ubuntu-latest | |||||
| steps: | |||||
| - uses: actions/checkout@v3 | |||||
| with: | |||||
| repository: ggerganov/llama.cpp | |||||
| - name: Build | |||||
| id: cmake_build | |||||
| run: | | |||||
| mkdir build | |||||
| cd build | |||||
| cmake .. ${{ matrix.defines }} | |||||
| cmake --build . --config Release | |||||
| - uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: ./build/libllama.so | |||||
| name: llama-bin-linux-${{ matrix.build }}-x64.so | |||||
| compile-windows: | |||||
| name: Compile (Windows) | |||||
| strategy: | |||||
| fail-fast: true | |||||
| matrix: | |||||
| include: | |||||
| - build: 'noavx' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON' | |||||
| - build: 'avx2' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON' | |||||
| - build: 'avx' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DBUILD_SHARED_LIBS=ON' | |||||
| - build: 'avx512' | |||||
| defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON' | |||||
| runs-on: windows-latest | |||||
| steps: | |||||
| - uses: actions/checkout@v3 | |||||
| with: | |||||
| repository: ggerganov/llama.cpp | |||||
| - name: Build | |||||
| id: cmake_build | |||||
| run: | | |||||
| mkdir build | |||||
| cd build | |||||
| cmake .. ${{ matrix.defines }} | |||||
| cmake --build . --config Release | |||||
| - name: Upload artifacts | |||||
| uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: .\build\bin\Release\llama.dll | |||||
| name: llama-bin-win-${{ matrix.build }}-x64.dll | |||||
| compile-cublas: | |||||
| if: ${{ github.event.inputs.cublas }} | |||||
| name: Compile (cublas) | |||||
| strategy: | |||||
| fail-fast: false | |||||
| matrix: | |||||
| os: [ubuntu-latest, windows-latest] | |||||
| cuda: ['12.1.0', '11.7.1'] | |||||
| runs-on: ${{ matrix.os }} | |||||
| steps: | |||||
| - name: Clone | |||||
| id: checkout | |||||
| uses: actions/checkout@v3 | |||||
| with: | |||||
| repository: ggerganov/llama.cpp | |||||
| - uses: Jimver/cuda-toolkit@v0.2.10 | |||||
| id: cuda-toolkit | |||||
| with: | |||||
| cuda: ${{ matrix.cuda }} | |||||
| - name: Build | |||||
| id: cmake_build | |||||
| run: | | |||||
| mkdir build | |||||
| cd build | |||||
| cmake .. -DLLAMA_CUBLAS=ON -DBUILD_SHARED_LIBS=ON -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF | |||||
| cmake --build . --config Release | |||||
| ls -R | |||||
| - name: Upload artifacts (Windows) | |||||
| if: ${{ matrix.os == 'windows-latest' }} | |||||
| uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: .\build\bin\Release\llama.dll | |||||
| name: llama-bin-win-cublas-cu${{ matrix.cuda }}-x64.dll | |||||
| - name: Upload artifacts (Linux) | |||||
| if: ${{ matrix.os == 'ubuntu-latest' }} | |||||
| uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: ./build/libllama.so | |||||
| name: llama-bin-linux-cublas-cu${{ matrix.cuda }}-x64.so | |||||
| compile-macos: | |||||
| if: ${{ github.event.inputs.macos }} | |||||
| name: Compile (MacOS) | |||||
| runs-on: macos-latest | |||||
| strategy: | |||||
| fail-fast: true | |||||
| matrix: | |||||
| arch: [ | |||||
| "arm64" | |||||
| ] | |||||
| steps: | |||||
| - uses: actions/checkout@v3 | |||||
| with: | |||||
| repository: ggerganov/llama.cpp | |||||
| - name: Dependencies | |||||
| continue-on-error: true | |||||
| run: | | |||||
| brew update | |||||
| - name: Build | |||||
| id: cmake_build | |||||
| run: | | |||||
| mkdir build | |||||
| cd build | |||||
| cmake -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_OSX_ARCHITECTURES=${{ matrix.arch }} .. | |||||
| cmake --build . --config Release | |||||
| - name: Upload artifacts | |||||
| uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: ./build/libllama.dylib | |||||
| name: llama-bin-macos-${{ matrix.arch }}.dylib | |||||
| build-deps: | |||||
| runs-on: ubuntu-latest | |||||
| name: "Gather Binaries" | |||||
| if: ${{ always() }} | |||||
| needs: [ | |||||
| "compile-linux", | |||||
| "compile-macos", | |||||
| "compile-windows", | |||||
| "compile-cublas" | |||||
| ] | |||||
| steps: | |||||
| - uses: actions/download-artifact@v3 | |||||
| with: | |||||
| path: artifacts | |||||
| - name: Rearrange Files | |||||
| run: | | |||||
| ls -R | |||||
| mkdir deps | |||||
| mkdir deps/linux | |||||
| mkdir deps/linux/noavx | |||||
| cp artifacts/llama-bin-linux-noavx-x64.so/libllama.so deps/linux/noavx/libllama.so | |||||
| mkdir deps/linux/avx | |||||
| cp artifacts/llama-bin-linux-avx-x64.so/libllama.so deps/linux/avx/libllama.so | |||||
| mkdir deps/linux/avx2 | |||||
| cp artifacts/llama-bin-linux-avx2-x64.so/libllama.so deps/linux/avx2/libllama.so | |||||
| mkdir deps/linux/avx512 | |||||
| cp artifacts/llama-bin-linux-avx512-x64.so/libllama.so deps/linux/avx512/libllama.so | |||||
| mkdir deps/win | |||||
| mkdir deps/win/noavx | |||||
| cp artifacts/llama-bin-win-noavx-x64.dll/llama.dll deps/win/noavx/libllama.dll | |||||
| mkdir deps/win/avx | |||||
| cp artifacts/llama-bin-win-avx-x64.dll/llama.dll deps/win/avx/libllama.dll | |||||
| mkdir deps/win/avx2 | |||||
| cp artifacts/llama-bin-win-avx2-x64.dll/llama.dll deps/win/avx2/libllama.dll | |||||
| mkdir deps/win/avx512 | |||||
| cp artifacts/llama-bin-win-avx512-x64.dll/llama.dll deps/win/avx512/libllama.dll | |||||
| - name: Rearrange MacOS files | |||||
| if: ${{ github.event.inputs.macos }} | |||||
| run: | | |||||
| mkdir deps/macos-arm64 | |||||
| cp artifacts/llama-bin-macos-arm64.dylib/libllama.dylib deps/macos-arm64/libllama.dylib | |||||
| - name: Rearrange CUDA files | |||||
| if: ${{ github.event.inputs.cublas }} | |||||
| run: | | |||||
| mkdir cu11.7.1 | |||||
| cp artifacts/llama-bin-win-cublas-cu11.7.1-x64.dll/llama.dll cu11.7.1/libllama.dll | |||||
| cp artifacts/llama-bin-linux-cublas-cu11.7.1-x64.so/libllama.so cu11.7.1/libllama.so | |||||
| mkdir cu12.1.0 | |||||
| cp artifacts/llama-bin-win-cublas-cu12.1.0-x64.dll/llama.dll cu12.1.0/libllama.dll | |||||
| cp artifacts/llama-bin-linux-cublas-cu12.1.0-x64.so/libllama.so cu12.1.0/libllama.so | |||||
| - name: Upload artifacts | |||||
| uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: deps/ | |||||
| name: deps | |||||
| - name: Upload artifacts (CUDA12) | |||||
| if: ${{ github.event.inputs.cublas }} | |||||
| uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: cu12.1.0/ | |||||
| name: cu12.1.0 | |||||
| - name: Upload artifacts (CUDA11) | |||||
| if: ${{ github.event.inputs.cublas }} | |||||
| uses: actions/upload-artifact@v3 | |||||
| with: | |||||
| path: cu11.7.1/ | |||||
| name: cu11.7.1 | |||||
| - name: Remove Artifacts | |||||
| uses: geekyeggo/delete-artifact@v2 | |||||
| with: | |||||
| name: | | |||||
| llama-* | |||||
| @@ -344,4 +344,5 @@ test/TensorFlowNET.Examples/mnist | |||||
| site/ | site/ | ||||
| /LLama.Unittest/Models/*.bin | /LLama.Unittest/Models/*.bin | ||||
| /LLama.Unittest/Models/*.gguf | |||||
| @@ -0,0 +1,27 @@ | |||||
| # https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/grammars/json.gbnf | |||||
| root ::= object | |||||
| value ::= object | array | string | number | ("true" | "false" | "null") ws | |||||
| object ::= | |||||
| "{" ws ( | |||||
| string ":" ws value | |||||
| ("," ws string ":" ws value)* | |||||
| )? "}" ws | |||||
| array ::= | |||||
| "[" ws ( | |||||
| value | |||||
| ("," ws value)* | |||||
| )? "]" ws | |||||
| string ::= | |||||
| "\"" ( | |||||
| [^"\\] | | |||||
| "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes | |||||
| )* "\"" ws | |||||
| number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws | |||||
| # Optional space: by convention, applied in this grammar after literal chars when allowed | |||||
| ws ::= ([ \t\n] ws)? | |||||
| @@ -27,6 +27,11 @@ | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Microsoft.SemanticKernel" Version="0.21.230828.2-preview" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\LLama.SemanticKernel\LLamaSharp.SemanticKernel.csproj" /> | |||||
| <ProjectReference Include="..\LLama\LLamaSharp.csproj" /> | <ProjectReference Include="..\LLama\LLamaSharp.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -49,6 +54,9 @@ | |||||
| <None Update="Assets\dan.txt"> | <None Update="Assets\dan.txt"> | ||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
| </None> | </None> | ||||
| <None Update="Assets\json.gbnf"> | |||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
| </None> | |||||
| <None Update="Assets\reason-act.txt"> | <None Update="Assets\reason-act.txt"> | ||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
| </None> | </None> | ||||
| @@ -1,9 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -12,15 +7,27 @@ namespace LLama.Examples.NewVersion | |||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | ||||
| InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); | |||||
| ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| var executor = new InteractiveExecutor(context); | |||||
| var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The chat session has started. The role names won't be printed."); | Console.WriteLine("The chat session has started. The role names won't be printed."); | ||||
| Console.ForegroundColor = ConsoleColor.White; | Console.ForegroundColor = ConsoleColor.White; | ||||
| // show the prompt | |||||
| Console.Write(prompt); | |||||
| while (true) | while (true) | ||||
| { | { | ||||
| foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } })) | foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } })) | ||||
| @@ -1,9 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -12,10 +7,20 @@ namespace LLama.Examples.NewVersion | |||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | ||||
| InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); | |||||
| ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| var executor = new InteractiveExecutor(context); | |||||
| var session = new ChatSession(executor); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); | Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); | ||||
| @@ -1,9 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -12,7 +7,7 @@ namespace LLama.Examples.NewVersion | |||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var embedder = new LLamaEmbedder(new ModelParams(modelPath)); | var embedder = new LLamaEmbedder(new ModelParams(modelPath)); | ||||
| while (true) | while (true) | ||||
| @@ -0,0 +1,53 @@ | |||||
| using LLama.Common; | |||||
| using LLama.Grammars; | |||||
| namespace LLama.Examples.NewVersion | |||||
| { | |||||
| public class GrammarJsonResponse | |||||
| { | |||||
| public static void Run() | |||||
| { | |||||
| var gbnf = File.ReadAllText("Assets/json.gbnf").Trim(); | |||||
| var grammar = Grammar.Parse(gbnf, "root"); | |||||
| Console.Write("Please input your model path: "); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| var ex = new StatelessExecutor(model, parameters); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||||
| Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\""); | |||||
| Console.ForegroundColor = ConsoleColor.White; | |||||
| using var grammarInstance = grammar.CreateInstance(); | |||||
| var inferenceParams = new InferenceParams() | |||||
| { | |||||
| Temperature = 0.6f, | |||||
| AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" }, | |||||
| MaxTokens = 50, | |||||
| Grammar = grammarInstance | |||||
| }; | |||||
| while (true) | |||||
| { | |||||
| Console.Write("\nQuestion: "); | |||||
| Console.ForegroundColor = ConsoleColor.Green; | |||||
| var prompt = Console.ReadLine(); | |||||
| Console.ForegroundColor = ConsoleColor.White; | |||||
| Console.Write("Answer: "); | |||||
| prompt = $"Question: {prompt?.Trim()} Answer: "; | |||||
| foreach (var text in ex.Infer(prompt, inferenceParams)) | |||||
| { | |||||
| Console.Write(text); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,9 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -12,10 +7,18 @@ namespace LLama.Examples.NewVersion | |||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var prompt = File.ReadAllText("Assets/dan.txt").Trim(); | var prompt = File.ReadAllText("Assets/dan.txt").Trim(); | ||||
| InstructExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024))); | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| var executor = new InstructExecutor(context); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. For example, you can input \"Write a story about a fox who want to " + | Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. For example, you can input \"Write a story about a fox who want to " + | ||||
| @@ -26,7 +29,7 @@ namespace LLama.Examples.NewVersion | |||||
| while (true) | while (true) | ||||
| { | { | ||||
| foreach (var text in ex.Infer(prompt, inferenceParams)) | |||||
| foreach (var text in executor.Infer(prompt, inferenceParams)) | |||||
| { | { | ||||
| Console.Write(text); | Console.Write(text); | ||||
| } | } | ||||
| @@ -1,21 +1,24 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| public class InteractiveModeExecute | public class InteractiveModeExecute | ||||
| { | { | ||||
| public async static Task Run() | |||||
| public static async Task Run() | |||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim(); | |||||
| InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| var ex = new InteractiveExecutor(context); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)"); | Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)"); | ||||
| @@ -1,10 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.OldVersion; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -13,10 +7,20 @@ namespace LLama.Examples.NewVersion | |||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | ||||
| InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); | |||||
| ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| var ex = new InteractiveExecutor(context); | |||||
| var session = new ChatSession(ex); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result. Input \"save\" to save and reload the session."); | Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result. Input \"save\" to save and reload the session."); | ||||
| @@ -45,8 +49,8 @@ namespace LLama.Examples.NewVersion | |||||
| Console.WriteLine("Saved session!"); | Console.WriteLine("Saved session!"); | ||||
| Console.ForegroundColor = ConsoleColor.White; | Console.ForegroundColor = ConsoleColor.White; | ||||
| ex.Model.Dispose(); | |||||
| ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); | |||||
| ex.Context.Dispose(); | |||||
| ex = new(new LLamaContext(parameters)); | |||||
| session = new ChatSession(ex); | session = new ChatSession(ex); | ||||
| session.LoadSession(statePath); | session.LoadSession(statePath); | ||||
| @@ -1,9 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -12,10 +7,18 @@ namespace LLama.Examples.NewVersion | |||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | ||||
| InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| var ex = new InteractiveExecutor(context); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. (an example for small scale usage)"); | Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. (an example for small scale usage)"); | ||||
| @@ -36,20 +39,20 @@ namespace LLama.Examples.NewVersion | |||||
| if (prompt == "save") | if (prompt == "save") | ||||
| { | { | ||||
| Console.Write("Your path to save model state: "); | Console.Write("Your path to save model state: "); | ||||
| string modelStatePath = Console.ReadLine(); | |||||
| ex.Model.SaveState(modelStatePath); | |||||
| var modelStatePath = Console.ReadLine(); | |||||
| ex.Context.SaveState(modelStatePath); | |||||
| Console.Write("Your path to save executor state: "); | Console.Write("Your path to save executor state: "); | ||||
| string executorStatePath = Console.ReadLine(); | |||||
| var executorStatePath = Console.ReadLine(); | |||||
| ex.SaveState(executorStatePath); | ex.SaveState(executorStatePath); | ||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("All states saved!"); | Console.WriteLine("All states saved!"); | ||||
| Console.ForegroundColor = ConsoleColor.White; | Console.ForegroundColor = ConsoleColor.White; | ||||
| var model = ex.Model; | |||||
| model.LoadState(modelStatePath); | |||||
| ex = new InteractiveExecutor(model); | |||||
| var ctx = ex.Context; | |||||
| ctx.LoadState(modelStatePath); | |||||
| ex = new InteractiveExecutor(ctx); | |||||
| ex.LoadState(executorStatePath); | ex.LoadState(executorStatePath); | ||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("Loaded state!"); | Console.WriteLine("Loaded state!"); | ||||
| @@ -1,11 +1,4 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | |||||
| namespace LLama.Examples.NewVersion | |||||
| { | { | ||||
| public class QuantizeModel | public class QuantizeModel | ||||
| { | { | ||||
| @@ -13,13 +6,16 @@ namespace LLama.Examples.NewVersion | |||||
| { | { | ||||
| Console.Write("Please input your original model path: "); | Console.Write("Please input your original model path: "); | ||||
| var inputPath = Console.ReadLine(); | var inputPath = Console.ReadLine(); | ||||
| Console.Write("Please input your output model path: "); | Console.Write("Please input your output model path: "); | ||||
| var outputPath = Console.ReadLine(); | var outputPath = Console.ReadLine(); | ||||
| Console.Write("Please input the quantize type (one of q4_0, q4_1, q5_0, q5_1, q8_0): "); | Console.Write("Please input the quantize type (one of q4_0, q4_1, q5_0, q5_1, q8_0): "); | ||||
| var quantizeType = Console.ReadLine(); | var quantizeType = Console.ReadLine(); | ||||
| if (LLamaQuantizer.Quantize(inputPath, outputPath, quantizeType)) | if (LLamaQuantizer.Quantize(inputPath, outputPath, quantizeType)) | ||||
| { | { | ||||
| Console.WriteLine("Quantization succeed!"); | |||||
| Console.WriteLine("Quantization succeeded!"); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -0,0 +1,69 @@ | |||||
| using System.Reflection.Metadata; | |||||
| using System.Security.Cryptography; | |||||
| using System.Text; | |||||
| using LLama.Abstractions; | |||||
| using LLama.Common; | |||||
| using Microsoft.SemanticKernel; | |||||
| using Microsoft.SemanticKernel.AI.ChatCompletion; | |||||
| using Microsoft.SemanticKernel.AI.TextCompletion; | |||||
| using LLamaSharp.SemanticKernel.ChatCompletion; | |||||
| using LLamaSharp.SemanticKernel.TextCompletion; | |||||
| namespace LLama.Examples.NewVersion | |||||
| { | |||||
| public class SemanticKernelChat | |||||
| { | |||||
| public static async Task Run() | |||||
| { | |||||
| Console.WriteLine("Example from: https://github.com/microsoft/semantic-kernel/blob/main/dotnet/README.md"); | |||||
| Console.Write("Please input your model path: "); | |||||
| var modelPath = Console.ReadLine(); | |||||
| // Load weights into memory | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| Seed = RandomNumberGenerator.GetInt32(int.MaxValue), | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| var ex = new InteractiveExecutor(context); | |||||
| var chatGPT = new LLamaSharpChatCompletion(ex); | |||||
| var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books"); | |||||
| Console.WriteLine("Chat content:"); | |||||
| Console.WriteLine("------------------------"); | |||||
| chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); | |||||
| await MessageOutputAsync(chatHistory); | |||||
| // First bot assistant message | |||||
| string reply = await chatGPT.GenerateMessageAsync(chatHistory); | |||||
| chatHistory.AddAssistantMessage(reply); | |||||
| await MessageOutputAsync(chatHistory); | |||||
| // Second user message | |||||
| chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); | |||||
| await MessageOutputAsync(chatHistory); | |||||
| // Second bot assistant message | |||||
| reply = await chatGPT.GenerateMessageAsync(chatHistory); | |||||
| chatHistory.AddAssistantMessage(reply); | |||||
| await MessageOutputAsync(chatHistory); | |||||
| } | |||||
| /// <summary> | |||||
| /// Outputs the last message of the chat history | |||||
| /// </summary> | |||||
| private static Task MessageOutputAsync(Microsoft.SemanticKernel.AI.ChatCompletion.ChatHistory chatHistory) | |||||
| { | |||||
| var message = chatHistory.Messages.Last(); | |||||
| Console.WriteLine($"{message.Role}: {message.Content}"); | |||||
| Console.WriteLine("------------------------"); | |||||
| return Task.CompletedTask; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,55 @@ | |||||
| using System.Reflection.Metadata; | |||||
| using System.Security.Cryptography; | |||||
| using System.Text; | |||||
| using LLama.Abstractions; | |||||
| using LLama.Common; | |||||
| using Microsoft.SemanticKernel; | |||||
| using Microsoft.SemanticKernel.AI.ChatCompletion; | |||||
| using Microsoft.SemanticKernel.AI.TextCompletion; | |||||
| using LLamaSharp.SemanticKernel.TextCompletion; | |||||
| namespace LLama.Examples.NewVersion | |||||
| { | |||||
| public class SemanticKernelPrompt | |||||
| { | |||||
| public static async Task Run() | |||||
| { | |||||
| Console.WriteLine("Example from: https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/KernelSyntaxExamples/Example17_ChatGPT.cs"); | |||||
| Console.Write("Please input your model path: "); | |||||
| var modelPath = Console.ReadLine(); | |||||
| // Load weights into memory | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| Seed = RandomNumberGenerator.GetInt32(int.MaxValue), | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| var ex = new StatelessExecutor(model, parameters); | |||||
| var builder = new KernelBuilder(); | |||||
| builder.WithAIService<ITextCompletion>("local-llama", new LLamaSharpTextCompletion(ex), true); | |||||
| var kernel = builder.Build(); | |||||
| var prompt = @"{{$input}} | |||||
| One line TLDR with the fewest words."; | |||||
| var summarize = kernel.CreateSemanticFunction(prompt, maxTokens: 100); | |||||
| string text1 = @" | |||||
| 1st Law of Thermodynamics - Energy cannot be created or destroyed. | |||||
| 2nd Law of Thermodynamics - For a spontaneous process, the entropy of the universe increases. | |||||
| 3rd Law of Thermodynamics - A perfect crystal at zero Kelvin has zero entropy."; | |||||
| string text2 = @" | |||||
| 1. An object at rest remains at rest, and an object in motion remains in motion at constant speed and in a straight line unless acted on by an unbalanced force. | |||||
| 2. The acceleration of an object depends on the mass of the object and the amount of force applied. | |||||
| 3. Whenever one object exerts a force on another object, the second object exerts an equal and opposite on the first."; | |||||
| Console.WriteLine(await summarize.InvokeAsync(text1)); | |||||
| Console.WriteLine(await summarize.InvokeAsync(text2)); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,9 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -12,9 +7,16 @@ namespace LLama.Examples.NewVersion | |||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| string modelPath = Console.ReadLine(); | |||||
| var modelPath = Console.ReadLine(); | |||||
| StatelessExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); | |||||
| var parameters = new ModelParams(modelPath) | |||||
| { | |||||
| ContextSize = 1024, | |||||
| Seed = 1337, | |||||
| GpuLayerCount = 5 | |||||
| }; | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| var ex = new StatelessExecutor(model, parameters); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. That says, the previous input and response has " + | Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. That says, the previous input and response has " + | ||||
| @@ -29,10 +31,10 @@ namespace LLama.Examples.NewVersion | |||||
| { | { | ||||
| Console.Write("\nQuestion: "); | Console.Write("\nQuestion: "); | ||||
| Console.ForegroundColor = ConsoleColor.Green; | Console.ForegroundColor = ConsoleColor.Green; | ||||
| string prompt = Console.ReadLine(); | |||||
| Console.ForegroundColor = ConsoleColor.White; | |||||
| var prompt = Console.ReadLine(); | |||||
| Console.ForegroundColor = ConsoleColor.White; | |||||
| Console.Write("Answer: "); | Console.Write("Answer: "); | ||||
| prompt = $"Question: {prompt.Trim()} Answer: "; | |||||
| prompt = $"Question: {prompt?.Trim()} Answer: "; | |||||
| foreach (var text in ex.Infer(prompt, inferenceParams)) | foreach (var text in ex.Infer(prompt, inferenceParams)) | ||||
| { | { | ||||
| Console.Write(text); | Console.Write(text); | ||||
| @@ -0,0 +1,74 @@ | |||||
| using System.Security.Cryptography; | |||||
| using System.Text; | |||||
| using LLama.Abstractions; | |||||
| using LLama.Common; | |||||
| namespace LLama.Examples.NewVersion | |||||
| { | |||||
| public class TalkToYourself | |||||
| { | |||||
| public static async Task Run() | |||||
| { | |||||
| Console.Write("Please input your model path: "); | |||||
| var modelPath = Console.ReadLine(); | |||||
| // Load weights into memory | |||||
| var @params = new ModelParams(modelPath) | |||||
| { | |||||
| Seed = RandomNumberGenerator.GetInt32(int.MaxValue) | |||||
| }; | |||||
| using var weights = LLamaWeights.LoadFromFile(@params); | |||||
| // Create 2 contexts sharing the same weights | |||||
| using var aliceCtx = weights.CreateContext(@params); | |||||
| var alice = new InteractiveExecutor(aliceCtx); | |||||
| using var bobCtx = weights.CreateContext(@params); | |||||
| var bob = new InteractiveExecutor(bobCtx); | |||||
| // Initial alice prompt | |||||
| var alicePrompt = "Transcript of a dialog, where the Alice interacts a person named Bob. Alice is friendly, kind, honest and good at writing.\nAlice: Hello"; | |||||
| var aliceResponse = await Prompt(alice, ConsoleColor.Green, alicePrompt, false, false); | |||||
| // Initial bob prompt | |||||
| var bobPrompt = $"Transcript of a dialog, where the Bob interacts a person named Alice. Bob is smart, intellectual and good at writing.\nAlice: Hello{aliceResponse}"; | |||||
| var bobResponse = await Prompt(bob, ConsoleColor.Red, bobPrompt, true, true); | |||||
| // swap back and forth from Alice to Bob | |||||
| while (true) | |||||
| { | |||||
| aliceResponse = await Prompt(alice, ConsoleColor.Green, bobResponse, false, true); | |||||
| bobResponse = await Prompt(bob, ConsoleColor.Red, aliceResponse, false, true); | |||||
| if (Console.KeyAvailable) | |||||
| break; | |||||
| } | |||||
| } | |||||
| private static async Task<string> Prompt(ILLamaExecutor executor, ConsoleColor color, string prompt, bool showPrompt, bool showResponse) | |||||
| { | |||||
| var inferenceParams = new InferenceParams | |||||
| { | |||||
| Temperature = 0.9f, | |||||
| AntiPrompts = new List<string> { "Alice:", "Bob:", "User:" }, | |||||
| MaxTokens = 128, | |||||
| Mirostat = MirostatType.Mirostat2, | |||||
| MirostatTau = 10, | |||||
| }; | |||||
| Console.ForegroundColor = ConsoleColor.White; | |||||
| if (showPrompt) | |||||
| Console.Write(prompt); | |||||
| Console.ForegroundColor = color; | |||||
| var builder = new StringBuilder(); | |||||
| await foreach (var text in executor.InferAsync(prompt, inferenceParams)) | |||||
| { | |||||
| builder.Append(text); | |||||
| if (showResponse) | |||||
| Console.Write(text); | |||||
| } | |||||
| return builder.ToString(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,10 +1,4 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples.NewVersion | |||||
| namespace LLama.Examples.NewVersion | |||||
| { | { | ||||
| public class NewVersionTestRunner | public class NewVersionTestRunner | ||||
| { | { | ||||
| @@ -14,7 +8,7 @@ namespace LLama.Examples.NewVersion | |||||
| Console.WriteLine("Please input a number to choose an example to run:"); | Console.WriteLine("Please input a number to choose an example to run:"); | ||||
| Console.WriteLine("0: Run a chat session without stripping the role names."); | Console.WriteLine("0: Run a chat session without stripping the role names."); | ||||
| Console.WriteLine("1: Run a chat session with the role names strippped."); | |||||
| Console.WriteLine("1: Run a chat session with the role names stripped."); | |||||
| Console.WriteLine("2: Interactive mode chat by using executor."); | Console.WriteLine("2: Interactive mode chat by using executor."); | ||||
| Console.WriteLine("3: Instruct mode chat by using executor."); | Console.WriteLine("3: Instruct mode chat by using executor."); | ||||
| Console.WriteLine("4: Stateless mode chat by using executor."); | Console.WriteLine("4: Stateless mode chat by using executor."); | ||||
| @@ -22,6 +16,10 @@ namespace LLama.Examples.NewVersion | |||||
| Console.WriteLine("6: Load and save state of model and executor."); | Console.WriteLine("6: Load and save state of model and executor."); | ||||
| Console.WriteLine("7: Get embeddings from LLama model."); | Console.WriteLine("7: Get embeddings from LLama model."); | ||||
| Console.WriteLine("8: Quantize the model."); | Console.WriteLine("8: Quantize the model."); | ||||
| Console.WriteLine("9: Automatic conversation."); | |||||
| Console.WriteLine("10: Constrain response to json format using grammar."); | |||||
| Console.WriteLine("11: Semantic Kernel Prompt."); | |||||
| Console.WriteLine("12: Semantic Kernel Chat."); | |||||
| while (true) | while (true) | ||||
| { | { | ||||
| @@ -64,6 +62,22 @@ namespace LLama.Examples.NewVersion | |||||
| { | { | ||||
| QuantizeModel.Run(); | QuantizeModel.Run(); | ||||
| } | } | ||||
| else if (choice == 9) | |||||
| { | |||||
| await TalkToYourself.Run(); | |||||
| } | |||||
| else if (choice == 10) | |||||
| { | |||||
| GrammarJsonResponse.Run(); | |||||
| } | |||||
| else if (choice == 11) | |||||
| { | |||||
| await SemanticKernelPrompt.Run(); | |||||
| } | |||||
| else if (choice == 12) | |||||
| { | |||||
| await SemanticKernelChat.Run(); | |||||
| } | |||||
| else | else | ||||
| { | { | ||||
| Console.WriteLine("Cannot parse your choice. Please select again."); | Console.WriteLine("Cannot parse your choice. Please select again."); | ||||
| @@ -7,6 +7,7 @@ using LLama.OldVersion; | |||||
| namespace LLama.Examples.Old | namespace LLama.Examples.Old | ||||
| { | { | ||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class ChatSession | public class ChatSession | ||||
| { | { | ||||
| LLama.OldVersion.ChatSession<LLama.OldVersion.LLamaModel> _session; | LLama.OldVersion.ChatSession<LLama.OldVersion.LLamaModel> _session; | ||||
| @@ -7,6 +7,7 @@ using LLama.OldVersion; | |||||
| namespace LLama.Examples.Old | namespace LLama.Examples.Old | ||||
| { | { | ||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class ChatWithLLamaModel | public class ChatWithLLamaModel | ||||
| { | { | ||||
| LLama.OldVersion.LLamaModel _model; | LLama.OldVersion.LLamaModel _model; | ||||
| @@ -7,6 +7,7 @@ using LLama.OldVersion; | |||||
| namespace LLama.Examples.Old | namespace LLama.Examples.Old | ||||
| { | { | ||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class GetEmbeddings | public class GetEmbeddings | ||||
| { | { | ||||
| LLama.OldVersion.LLamaEmbedder _embedder; | LLama.OldVersion.LLamaEmbedder _embedder; | ||||
| @@ -7,6 +7,7 @@ using LLama.OldVersion; | |||||
| namespace LLama.Examples.Old | namespace LLama.Examples.Old | ||||
| { | { | ||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class InstructMode | public class InstructMode | ||||
| { | { | ||||
| LLama.OldVersion.LLamaModel _model; | LLama.OldVersion.LLamaModel _model; | ||||
| @@ -7,6 +7,7 @@ using LLama.OldVersion; | |||||
| namespace LLama.Examples.Old | namespace LLama.Examples.Old | ||||
| { | { | ||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class SaveAndLoadState: IDisposable | public class SaveAndLoadState: IDisposable | ||||
| { | { | ||||
| LLama.OldVersion.LLamaModel _model; | LLama.OldVersion.LLamaModel _model; | ||||
| @@ -1,7 +1,4 @@ | |||||
| using LLama; | |||||
| using LLama.Common; | |||||
| using LLama.Examples; | |||||
| using LLama.Examples.NewVersion; | |||||
| using LLama.Examples.NewVersion; | |||||
| using LLama.Examples.Old; | using LLama.Examples.Old; | ||||
| Console.WriteLine("======================================================================================================"); | Console.WriteLine("======================================================================================================"); | ||||
| @@ -0,0 +1,17 @@ | |||||
| using static LLama.LLamaTransforms; | |||||
| namespace LLamaSharp.SemanticKernel.ChatCompletion; | |||||
| /// <summary> | |||||
| /// Default HistoryTransform Patch | |||||
| /// </summary> | |||||
| public class HistoryTransform : DefaultHistoryTransform | |||||
| { | |||||
| /// <inheritdoc/> | |||||
| public override string HistoryToText(global::LLama.Common.ChatHistory history) | |||||
| { | |||||
| var prompt = base.HistoryToText(history); | |||||
| return prompt + "\nAssistant:"; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,74 @@ | |||||
| using LLama; | |||||
| using Microsoft.SemanticKernel.AI.ChatCompletion; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | |||||
| using System.Threading; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLamaSharp.SemanticKernel.ChatCompletion; | |||||
| /// <summary> | |||||
| /// LLamaSharp ChatCompletion | |||||
| /// </summary> | |||||
| public sealed class LLamaSharpChatCompletion : IChatCompletion | |||||
| { | |||||
| private const string UserRole = "user:"; | |||||
| private const string AssistantRole = "assistant:"; | |||||
| private ChatSession session; | |||||
| public LLamaSharpChatCompletion(InteractiveExecutor model) | |||||
| { | |||||
| this.session = new ChatSession(model) | |||||
| .WithHistoryTransform(new HistoryTransform()) | |||||
| .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole })); | |||||
| } | |||||
| /// <inheritdoc/> | |||||
| public ChatHistory CreateNewChat(string? instructions = "") | |||||
| { | |||||
| var history = new ChatHistory(); | |||||
| if (instructions != null && !string.IsNullOrEmpty(instructions)) | |||||
| { | |||||
| history.AddSystemMessage(instructions); | |||||
| } | |||||
| return history; | |||||
| } | |||||
| /// <inheritdoc/> | |||||
| public async Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, ChatRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) | |||||
| { | |||||
| requestSettings ??= new ChatRequestSettings() | |||||
| { | |||||
| MaxTokens = 256, | |||||
| Temperature = 0, | |||||
| TopP = 0, | |||||
| StopSequences = new List<string> { } | |||||
| }; | |||||
| var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); | |||||
| return new List<IChatResult> { new LLamaSharpChatResult(result) }.AsReadOnly(); | |||||
| } | |||||
| /// <inheritdoc/> | |||||
| public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsAsync(ChatHistory chat, ChatRequestSettings? requestSettings = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | |||||
| { | |||||
| requestSettings ??= new ChatRequestSettings() | |||||
| { | |||||
| MaxTokens = 256, | |||||
| Temperature = 0, | |||||
| TopP = 0, | |||||
| StopSequences = new List<string> { } | |||||
| }; | |||||
| var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); | |||||
| yield return new LLamaSharpChatResult(result); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using Microsoft.SemanticKernel.AI.ChatCompletion; | |||||
| namespace LLamaSharp.SemanticKernel.ChatCompletion; | |||||
| /// <summary> | |||||
| /// LLamaSharp Chat Message | |||||
| /// </summary> | |||||
| public class LLamaSharpChatMessage : ChatMessageBase | |||||
| { | |||||
| /// <inheritdoc/> | |||||
| public LLamaSharpChatMessage(AuthorRole role, string content) : base(role, content) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,38 @@ | |||||
| using Microsoft.SemanticKernel.AI.ChatCompletion; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | |||||
| namespace LLamaSharp.SemanticKernel.ChatCompletion; | |||||
| internal sealed class LLamaSharpChatResult : IChatStreamingResult | |||||
| { | |||||
| private readonly IAsyncEnumerable<string> _stream; | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="stream"></param> | |||||
| public LLamaSharpChatResult(IAsyncEnumerable<string> stream) | |||||
| { | |||||
| _stream = stream; | |||||
| } | |||||
| /// <inheritdoc/> | |||||
| public async Task<ChatMessageBase> GetChatMessageAsync(CancellationToken cancellationToken = default) | |||||
| { | |||||
| var sb = new StringBuilder(); | |||||
| await foreach (var token in _stream) | |||||
| { | |||||
| sb.Append(token); | |||||
| } | |||||
| return await Task.FromResult(new LLamaSharpChatMessage(AuthorRole.Assistant, sb.ToString())).ConfigureAwait(false); | |||||
| } | |||||
| /// <inheritdoc/> | |||||
| public async IAsyncEnumerable<ChatMessageBase> GetStreamingChatMessageAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) | |||||
| { | |||||
| await foreach (var token in _stream) | |||||
| { | |||||
| yield return new LLamaSharpChatMessage(AuthorRole.Assistant, token); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,72 @@ | |||||
| using Microsoft.SemanticKernel.AI.ChatCompletion; | |||||
| using Microsoft.SemanticKernel.AI.TextCompletion; | |||||
| namespace LLamaSharp.SemanticKernel; | |||||
| internal static class ExtensionMethods | |||||
| { | |||||
| internal static global::LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory chatHistory) | |||||
| { | |||||
| if (chatHistory is null) | |||||
| { | |||||
| throw new ArgumentNullException(nameof(chatHistory)); | |||||
| } | |||||
| var history = new global::LLama.Common.ChatHistory(); | |||||
| foreach (var chat in chatHistory) | |||||
| { | |||||
| var role = Enum.TryParse<global::LLama.Common.AuthorRole>(chat.Role.Label, out var _role) ? _role : global::LLama.Common.AuthorRole.Unknown; | |||||
| history.AddMessage(role, chat.Content); | |||||
| } | |||||
| return history; | |||||
| } | |||||
| /// <summary> | |||||
| /// Convert ChatRequestSettings to LLamaSharp InferenceParams | |||||
| /// </summary> | |||||
| /// <param name="requestSettings"></param> | |||||
| /// <returns></returns> | |||||
| internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this ChatRequestSettings requestSettings) | |||||
| { | |||||
| if (requestSettings is null) | |||||
| { | |||||
| throw new ArgumentNullException(nameof(requestSettings)); | |||||
| } | |||||
| var antiPrompts = new List<string>(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" }; | |||||
| return new global::LLama.Common.InferenceParams | |||||
| { | |||||
| Temperature = (float)requestSettings.Temperature, | |||||
| TopP = (float)requestSettings.TopP, | |||||
| PresencePenalty = (float)requestSettings.PresencePenalty, | |||||
| FrequencyPenalty = (float)requestSettings.FrequencyPenalty, | |||||
| AntiPrompts = antiPrompts, | |||||
| MaxTokens = requestSettings.MaxTokens ?? -1 | |||||
| }; | |||||
| } | |||||
| /// <summary> | |||||
| /// Convert CompleteRequestSettings to LLamaSharp InferenceParams | |||||
| /// </summary> | |||||
| /// <param name="requestSettings"></param> | |||||
| /// <returns></returns> | |||||
| internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this CompleteRequestSettings requestSettings) | |||||
| { | |||||
| if (requestSettings is null) | |||||
| { | |||||
| throw new ArgumentNullException(nameof(requestSettings)); | |||||
| } | |||||
| return new global::LLama.Common.InferenceParams | |||||
| { | |||||
| Temperature = (float)requestSettings.Temperature, | |||||
| TopP = (float)requestSettings.TopP, | |||||
| PresencePenalty = (float)requestSettings.PresencePenalty, | |||||
| FrequencyPenalty = (float)requestSettings.FrequencyPenalty, | |||||
| AntiPrompts = requestSettings.StopSequences, | |||||
| MaxTokens = requestSettings.MaxTokens ?? -1 | |||||
| }; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,22 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | |||||
| <PropertyGroup> | |||||
| <TargetFrameworks>netstandard2.0;net6.0;net7.0</TargetFrameworks> | |||||
| <RootNamespace>LLamaSharp.SemanticKernel</RootNamespace> | |||||
| <Nullable>enable</Nullable> | |||||
| <LangVersion>10</LangVersion> | |||||
| <Platforms>AnyCPU;x64;Arm64</Platforms> | |||||
| <AllowUnsafeBlocks>True</AllowUnsafeBlocks> | |||||
| <ImplicitUsings>enable</ImplicitUsings> | |||||
| <Nullable>enable</Nullable> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | |||||
| <PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="0.21.230828.2-preview" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\LLama\LLamaSharp.csproj" /> | |||||
| </ItemGroup> | |||||
| </Project> | |||||
| @@ -0,0 +1,26 @@ | |||||
| # LLamaSharp.SemanticKernel | |||||
| LLamaSharp.SemanticKernel are connections for [SemanticKernel](https://github.com/microsoft/semantic-kernel): an SDK for intergrating various LLM interfaces into a single implementation. With this, you can add local LLaMa queries as another connection point with your existing connections. | |||||
| For reference on how to implement it, view the following examples: | |||||
| - [SemanticKernelChat](../LLama.Examples/NewVersion/SemanticKernelChat.cs) | |||||
| - [SemanticKernelPrompt](../LLama.Examples/NewVersion/SemanticKernelPrompt.cs) | |||||
| ## ITextCompletion | |||||
| ```csharp | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| // LLamaSharpTextCompletion can accept ILLamaExecutor. | |||||
| var ex = new StatelessExecutor(model, parameters); | |||||
| var builder = new KernelBuilder(); | |||||
| builder.WithAIService<ITextCompletion>("local-llama", new LLamaSharpTextCompletion(ex), true); | |||||
| ``` | |||||
| ## IChatCompletion | |||||
| ```csharp | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| using var context = model.CreateContext(parameters); | |||||
| // LLamaSharpChatCompletion requires InteractiveExecutor, as it's the best fit for the given command. | |||||
| var ex = new InteractiveExecutor(context); | |||||
| var chatGPT = new LLamaSharpChatCompletion(ex); | |||||
| ``` | |||||
| @@ -0,0 +1,27 @@ | |||||
| using LLama; | |||||
| using LLama.Abstractions; | |||||
| using Microsoft.SemanticKernel.AI.TextCompletion; | |||||
| namespace LLamaSharp.SemanticKernel.TextCompletion; | |||||
| public sealed class LLamaSharpTextCompletion : ITextCompletion | |||||
| { | |||||
| public ILLamaExecutor executor; | |||||
| public LLamaSharpTextCompletion(ILLamaExecutor executor) | |||||
| { | |||||
| this.executor = executor; | |||||
| } | |||||
| public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, CompleteRequestSettings requestSettings, CancellationToken cancellationToken = default) | |||||
| { | |||||
| var result = executor.InferAsync(text, requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); | |||||
| return await Task.FromResult(new List<ITextResult> { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false); | |||||
| } | |||||
| public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync(string text, CompleteRequestSettings requestSettings, CancellationToken cancellationToken = default) | |||||
| { | |||||
| var result = executor.InferAsync(text, requestSettings.ToLLamaSharpInferenceParams(), cancellationToken); | |||||
| yield return new LLamaTextResult(result); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,37 @@ | |||||
| using Microsoft.SemanticKernel.AI.TextCompletion; | |||||
| using Microsoft.SemanticKernel.Orchestration; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | |||||
| namespace LLamaSharp.SemanticKernel.TextCompletion; | |||||
| internal sealed class LLamaTextResult : ITextStreamingResult | |||||
| { | |||||
| private readonly IAsyncEnumerable<string> _text; | |||||
| public LLamaTextResult(IAsyncEnumerable<string> text) | |||||
| { | |||||
| _text = text; | |||||
| ModelResult = new(text); | |||||
| } | |||||
| public ModelResult ModelResult { get; } | |||||
| public async Task<string> GetCompletionAsync(CancellationToken cancellationToken = default) | |||||
| { | |||||
| var sb = new StringBuilder(); | |||||
| await foreach (var token in _text) | |||||
| { | |||||
| sb.Append(token); | |||||
| } | |||||
| return await Task.FromResult(sb.ToString()).ConfigureAwait(false); | |||||
| } | |||||
| public async IAsyncEnumerable<string> GetCompletionStreamingAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) | |||||
| { | |||||
| await foreach (string word in _text) | |||||
| { | |||||
| yield return word; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,15 +1,60 @@ | |||||
| using LLama; | |||||
| using LLama.Common; | using LLama.Common; | ||||
| namespace LLama.Unittest | namespace LLama.Unittest | ||||
| { | { | ||||
| public class BasicTest | public class BasicTest | ||||
| : IDisposable | |||||
| { | { | ||||
| private readonly ModelParams _params; | |||||
| private readonly LLamaWeights _model; | |||||
| public BasicTest() | |||||
| { | |||||
| _params = new ModelParams(Constants.ModelPath) | |||||
| { | |||||
| ContextSize = 2048 | |||||
| }; | |||||
| _model = LLamaWeights.LoadFromFile(_params); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| _model.Dispose(); | |||||
| } | |||||
| [Fact] | [Fact] | ||||
| public void LoadModel() | |||||
| public void BasicModelProperties() | |||||
| { | { | ||||
| var model = new LLamaModel(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 256)); | |||||
| model.Dispose(); | |||||
| Assert.Equal(32000, _model.VocabCount); | |||||
| Assert.Equal(2048, _model.ContextSize); | |||||
| Assert.Equal(4096, _model.EmbeddingSize); | |||||
| } | |||||
| [Fact] | |||||
| public void CloneContext() | |||||
| { | |||||
| var original = _model.CreateContext(_params); | |||||
| // Evaluate something (doesn't matter what, as long as it begins with token 1) | |||||
| original.Eval(new[] { 1, 42, 321 }, 0); | |||||
| // Clone current state | |||||
| var clone = original.Clone(); | |||||
| // Now evaluate something more | |||||
| var reply1a = original.Eval(new[] { 4, 5, 6 }, 3); | |||||
| var reply2a = original.Eval(new[] { 7, 8, 9 }, 6); | |||||
| // Assert that the context replied differently each time | |||||
| Assert.NotEqual(reply1a, reply2a); | |||||
| // Give the same prompts to the cloned state | |||||
| var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3); | |||||
| var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6); | |||||
| // Assert that the cloned context replied in the same way as originally | |||||
| Assert.Equal(reply1a, reply1b); | |||||
| Assert.Equal(reply2a, reply2b); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,7 @@ | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| internal static class Constants | |||||
| { | |||||
| public static string ModelPath = "Models/llama-2-7b.q4_0.gguf"; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,241 @@ | |||||
| using LLama.Native; | |||||
| using LLama.Grammars; | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| /// <summary> | |||||
| /// Source: | |||||
| /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/tests/test-grammar-parser.cpp | |||||
| /// | |||||
| /// The commit hash from URL is the actual commit hash that reflects current C# code. | |||||
| /// </summary> | |||||
| public sealed class GrammarParserTest | |||||
| { | |||||
| [Fact] | |||||
| public void ParseComplexGrammar() | |||||
| { | |||||
| GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); | |||||
| string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ | |||||
| expr ::= term ([-+*/] term)* | |||||
| term ::= [0-9]+"; | |||||
| var state = parsedGrammar.Parse(grammarBytes, "root"); | |||||
| Assert.Equal(0ul, state.StartRuleIndex); | |||||
| var expected = new List<KeyValuePair<string, uint>> | |||||
| { | |||||
| new KeyValuePair<string, uint>("expr", 2), | |||||
| new KeyValuePair<string, uint>("expr_5", 5), | |||||
| new KeyValuePair<string, uint>("expr_6", 6), | |||||
| new KeyValuePair<string, uint>("root", 0), | |||||
| new KeyValuePair<string, uint>("root_1", 1), | |||||
| new KeyValuePair<string, uint>("root_4", 4), | |||||
| new KeyValuePair<string, uint>("term", 3), | |||||
| new KeyValuePair<string, uint>("term_7", 7), | |||||
| }; | |||||
| foreach (var symbol in expected) | |||||
| { | |||||
| var rule = state.Rules[(int)symbol.Value]; | |||||
| Assert.Equal(symbol.Key, rule.Name); | |||||
| } | |||||
| var expectedRules = new List<LLamaGrammarElement> | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| }; | |||||
| uint index = 0; | |||||
| foreach (var rule in state.Rules) | |||||
| { | |||||
| // compare rule to expected rule | |||||
| for (uint i = 0; i < rule.Elements.Count; i++) | |||||
| { | |||||
| var element = rule.Elements[(int)i]; | |||||
| var expectedElement = expectedRules[(int)index]; | |||||
| // Pretty print error message before asserting | |||||
| if (expectedElement.Type != element.Type || expectedElement.Value != element.Value) | |||||
| { | |||||
| Console.Error.WriteLine($"index: {index}"); | |||||
| Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}"); | |||||
| Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}"); | |||||
| Console.Error.WriteLine("expected_element != actual_element"); | |||||
| } | |||||
| Assert.Equal(expectedElement.Type, element.Type); | |||||
| Assert.Equal(expectedElement.Value, element.Value); | |||||
| index++; | |||||
| } | |||||
| } | |||||
| Assert.NotEmpty(state.Rules); | |||||
| } | |||||
| [Fact] | |||||
| public void ParseExtraComplexGrammar() | |||||
| { | |||||
| GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); | |||||
| string grammarBytes = @" | |||||
| root ::= (expr ""="" ws term ""\n"")+ | |||||
| expr ::= term ([-+*/] term)* | |||||
| term ::= ident | num | ""("" ws expr "")"" ws | |||||
| ident ::= [a-z] [a-z0-9_]* ws | |||||
| num ::= [0-9]+ ws | |||||
| ws ::= [ \t\n]* | |||||
| "; | |||||
| var state = parsedGrammar.Parse(grammarBytes, "root"); | |||||
| Assert.Equal(0ul, state.StartRuleIndex); | |||||
| var expected = new List<KeyValuePair<string, uint>> | |||||
| { | |||||
| new KeyValuePair<string, uint>("expr", 2), | |||||
| new KeyValuePair<string, uint>("expr_6", 6), | |||||
| new KeyValuePair<string, uint>("expr_7", 7), | |||||
| new KeyValuePair<string, uint>("ident", 8), | |||||
| new KeyValuePair<string, uint>("ident_10", 10), | |||||
| new KeyValuePair<string, uint>("num", 9), | |||||
| new KeyValuePair<string, uint>("num_11", 11), | |||||
| new KeyValuePair<string, uint>("root", 0), | |||||
| new KeyValuePair<string, uint>("root_1", 1), | |||||
| new KeyValuePair<string, uint>("root_5", 5), | |||||
| new KeyValuePair<string, uint>("term", 4), | |||||
| new KeyValuePair<string, uint>("ws", 3), | |||||
| new KeyValuePair<string, uint>("ws_12", 12), | |||||
| }; | |||||
| foreach (var symbol in expected) | |||||
| { | |||||
| var rule = state.Rules[(int)symbol.Value]; | |||||
| Assert.Equal(symbol.Key, rule.Name); | |||||
| } | |||||
| var expectedRules = new List<LLamaGrammarElement> | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 8), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 9), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 40), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 41), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 48), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 95), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 9), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 10), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0) | |||||
| }; | |||||
| uint index = 0; | |||||
| foreach (var rule in state.Rules) | |||||
| { | |||||
| // compare rule to expected rule | |||||
| for (uint i = 0; i < rule.Elements.Count; i++) | |||||
| { | |||||
| var element = rule.Elements[(int)i]; | |||||
| var expectedElement = expectedRules[(int)index]; | |||||
| // Pretty print error message before asserting | |||||
| if (expectedElement.Type != element.Type || expectedElement.Value != element.Value) | |||||
| { | |||||
| Console.Error.WriteLine($"index: {index}"); | |||||
| Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}"); | |||||
| Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}"); | |||||
| Console.Error.WriteLine("expected_element != actual_element"); | |||||
| } | |||||
| Assert.Equal(expectedElement.Type, element.Type); | |||||
| Assert.Equal(expectedElement.Value, element.Value); | |||||
| index++; | |||||
| } | |||||
| } | |||||
| Assert.NotEmpty(state.Rules); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,74 @@ | |||||
| using LLama.Common; | |||||
| using LLama.Grammars; | |||||
| using LLama.Native; | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| public sealed class GrammarTest | |||||
| : IDisposable | |||||
| { | |||||
| private readonly ModelParams _params; | |||||
| private readonly LLamaWeights _model; | |||||
| public GrammarTest() | |||||
| { | |||||
| _params = new ModelParams(Constants.ModelPath) | |||||
| { | |||||
| ContextSize = 2048, | |||||
| }; | |||||
| _model = LLamaWeights.LoadFromFile(_params); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| _model.Dispose(); | |||||
| } | |||||
| [Fact] | |||||
| public void CreateBasicGrammar() | |||||
| { | |||||
| var rules = new List<GrammarRule> | |||||
| { | |||||
| new GrammarRule("alpha", new[] | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| }), | |||||
| }; | |||||
| using var handle = SafeLLamaGrammarHandle.Create(rules, 0); | |||||
| } | |||||
| [Fact] | |||||
| public void SampleWithTrivialGrammar() | |||||
| { | |||||
| // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so | |||||
| // we can be confident it's not what the LLM would say if not constrained by the grammar! | |||||
| var rules = new List<GrammarRule> | |||||
| { | |||||
| new GrammarRule("feline", new [] | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| }), | |||||
| }; | |||||
| using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); | |||||
| var executor = new StatelessExecutor(_model, _params); | |||||
| var inferenceParams = new InferenceParams | |||||
| { | |||||
| MaxTokens = 3, | |||||
| AntiPrompts = new [] { ".", "Input:", "\n" }, | |||||
| Grammar = grammar, | |||||
| }; | |||||
| var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); | |||||
| Assert.Equal("cat", result[0]); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -11,20 +11,20 @@ | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||||
| <PackageReference Include="xunit" Version="2.4.2" /> | |||||
| <PackageReference Include="xunit.runner.visualstudio" Version="2.4.5"> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" /> | |||||
| <PackageReference Include="xunit" Version="2.5.0" /> | |||||
| <PackageReference Include="xunit.runner.visualstudio" Version="2.5.0"> | |||||
| <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
| <PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
| </PackageReference> | </PackageReference> | ||||
| <PackageReference Include="coverlet.collector" Version="3.1.2"> | |||||
| <PackageReference Include="coverlet.collector" Version="6.0.0"> | |||||
| <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
| <PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
| </PackageReference> | </PackageReference> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| <Target Name="DownloadContentFiles" BeforeTargets="Build"> | <Target Name="DownloadContentFiles" BeforeTargets="Build"> | ||||
| <DownloadFile SourceUrl="https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_S.bin" DestinationFolder="Models" DestinationFileName="llama-2-7b-chat.ggmlv3.q3_K_S.bin" SkipUnchangedFiles="true"> | |||||
| <DownloadFile SourceUrl="https://huggingface.co/narrative-bi/Llama-2-7B-GGUF/resolve/main/llama-2-7b.q4_0.gguf" DestinationFolder="Models" DestinationFileName="llama-2-7b.q4_0.gguf" SkipUnchangedFiles="true"> | |||||
| </DownloadFile> | </DownloadFile> | ||||
| </Target> | </Target> | ||||
| @@ -37,7 +37,7 @@ | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <None Update="Models\llama-2-7b-chat.ggmlv3.q3_K_S.bin"> | |||||
| <None Update="Models\llama-2-7b.q4_0.gguf"> | |||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
| </None> | </None> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -0,0 +1,35 @@ | |||||
| using LLama.Common; | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| public class LLamaContextTests | |||||
| : IDisposable | |||||
| { | |||||
| private readonly LLamaWeights _weights; | |||||
| private readonly LLamaContext _context; | |||||
| public LLamaContextTests() | |||||
| { | |||||
| var @params = new ModelParams(Constants.ModelPath) | |||||
| { | |||||
| ContextSize = 768, | |||||
| }; | |||||
| _weights = LLamaWeights.LoadFromFile(@params); | |||||
| _context = _weights.CreateContext(@params); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| _weights.Dispose(); | |||||
| _context.Dispose(); | |||||
| } | |||||
| [Fact] | |||||
| public void CheckProperties() | |||||
| { | |||||
| Assert.Equal(768, _context.ContextSize); | |||||
| Assert.Equal(4096, _context.EmbeddingSize); | |||||
| Assert.Equal(32000, _context.VocabCount); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,71 @@ | |||||
| using LLama.Common; | |||||
| namespace LLama.Unittest; | |||||
| public class LLamaEmbedderTests | |||||
| : IDisposable | |||||
| { | |||||
| private readonly LLamaEmbedder _embedder = new(new ModelParams(Constants.ModelPath)); | |||||
| public void Dispose() | |||||
| { | |||||
| _embedder.Dispose(); | |||||
| } | |||||
| private static float Magnitude(float[] a) | |||||
| { | |||||
| return MathF.Sqrt(a.Zip(a, (x, y) => x * y).Sum()); | |||||
| } | |||||
| private static void Normalize(float[] a) | |||||
| { | |||||
| var mag = Magnitude(a); | |||||
| for (var i = 0; i < a.Length; i++) | |||||
| a[i] /= mag; | |||||
| } | |||||
| private static float Dot(float[] a, float[] b) | |||||
| { | |||||
| Assert.Equal(a.Length, b.Length); | |||||
| return a.Zip(b, (x, y) => x * y).Sum(); | |||||
| } | |||||
| private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f) | |||||
| { | |||||
| for (int i = 0; i < expected.Length; i++) | |||||
| Assert.Equal(expected[i], actual[i], epsilon); | |||||
| } | |||||
| // todo: enable this one llama2 7B gguf is available | |||||
| //[Fact] | |||||
| //public void EmbedBasic() | |||||
| //{ | |||||
| // var cat = _embedder.GetEmbeddings("cat"); | |||||
| // Assert.NotNull(cat); | |||||
| // Assert.NotEmpty(cat); | |||||
| // // Expected value generate with llama.cpp embedding.exe | |||||
| // var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f }; | |||||
| // AssertApproxStartsWith(expected, cat); | |||||
| //} | |||||
| [Fact] | |||||
| public void EmbedCompare() | |||||
| { | |||||
| var cat = _embedder.GetEmbeddings("cat"); | |||||
| var kitten = _embedder.GetEmbeddings("kitten"); | |||||
| var spoon = _embedder.GetEmbeddings("spoon"); | |||||
| Normalize(cat); | |||||
| Normalize(kitten); | |||||
| Normalize(spoon); | |||||
| var close = Dot(cat, kitten); | |||||
| var far = Dot(cat, spoon); | |||||
| // This comparison seems backwards, but remember that with a | |||||
| // dot product 1.0 means **identical** and 0.0 means **completely opposite**! | |||||
| Assert.True(close > far); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,70 @@ | |||||
| using System.Text; | |||||
| using LLama.Common; | |||||
| using Newtonsoft.Json; | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| public class ModelsParamsTests | |||||
| { | |||||
| [Fact] | |||||
| public void SerializeRoundTripSystemTextJson() | |||||
| { | |||||
| var expected = new ModelParams("abc/123") | |||||
| { | |||||
| BatchSize = 17, | |||||
| ContextSize = 42, | |||||
| LoraAdapter = "adapter", | |||||
| Seed = 42, | |||||
| GpuLayerCount = 111 | |||||
| }; | |||||
| var json = System.Text.Json.JsonSerializer.Serialize(expected); | |||||
| var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json); | |||||
| Assert.Equal(expected, actual); | |||||
| } | |||||
| [Fact] | |||||
| public void SerializeRoundTripNewtonsoft() | |||||
| { | |||||
| var expected = new ModelParams("abc/123") | |||||
| { | |||||
| BatchSize = 17, | |||||
| ContextSize = 42, | |||||
| LoraAdapter = "adapter", | |||||
| Seed = 42, | |||||
| GpuLayerCount = 111 | |||||
| }; | |||||
| var settings = new Newtonsoft.Json.JsonSerializerSettings(); | |||||
| settings.Converters.Add(new NewtsonsoftEncodingConverter()); | |||||
| var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings); | |||||
| var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json, settings); | |||||
| Assert.Equal(expected, actual); | |||||
| } | |||||
| public class NewtsonsoftEncodingConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return typeof(Encoding).IsAssignableFrom(objectType); | |||||
| } | |||||
| public override void WriteJson(JsonWriter writer, object value, JsonSerializer serializer) | |||||
| { | |||||
| writer.WriteValue(((Encoding)value).WebName); | |||||
| } | |||||
| public override object ReadJson(JsonReader reader, Type objectType, object existingValue, JsonSerializer serializer) | |||||
| { | |||||
| return Encoding.GetEncoding((string)reader.Value); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,70 @@ | |||||
| using LLama.Common; | |||||
| using Xunit.Abstractions; | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| public class StatelessExecutorTest | |||||
| : IDisposable | |||||
| { | |||||
| private readonly ITestOutputHelper _testOutputHelper; | |||||
| private readonly LLamaWeights _weights; | |||||
| private readonly ModelParams _params; | |||||
| public StatelessExecutorTest(ITestOutputHelper testOutputHelper) | |||||
| { | |||||
| _testOutputHelper = testOutputHelper; | |||||
| _params = new ModelParams(Constants.ModelPath) | |||||
| { | |||||
| ContextSize = 60, | |||||
| Seed = 1754 | |||||
| }; | |||||
| _weights = LLamaWeights.LoadFromFile(_params); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| _weights.Dispose(); | |||||
| } | |||||
| [Fact] | |||||
| public void Stateless() | |||||
| { | |||||
| var executor = new StatelessExecutor(_weights, _params); | |||||
| const string question = "Question. what is a cat?\nAnswer: "; | |||||
| var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; | |||||
| var result1 = string.Join("", executor.Infer(question, @params)); | |||||
| var result2 = string.Join("", executor.Infer(question, @params)); | |||||
| _testOutputHelper.WriteLine(result1); | |||||
| // Check that it produced the exact same result both times | |||||
| Assert.Equal(result1, result2); | |||||
| } | |||||
| [Fact] | |||||
| public void OutOfContext() | |||||
| { | |||||
| var executor = new StatelessExecutor(_weights, _params); | |||||
| const string question = " Question. why is a cat the best pet?\nAnswer: "; | |||||
| // The context size is set to 60. Generate more than that, forcing it to generate a coherent response | |||||
| // with a modified context | |||||
| var @params = new InferenceParams() | |||||
| { | |||||
| MaxTokens = 100, | |||||
| TokensKeep = question.Length, | |||||
| }; | |||||
| var result1 = string.Join("", executor.Infer(question, @params)); | |||||
| var result2 = string.Join("", executor.Infer(question, @params)); | |||||
| _testOutputHelper.WriteLine(result1); | |||||
| // Check that it produced the exact same result both times | |||||
| Assert.Equal(result1, result2); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,8 +1,10 @@ | |||||
| using LLama.Abstractions; | |||||
| using System.Text; | |||||
| using LLama.Abstractions; | |||||
| namespace LLama.Web.Common | namespace LLama.Web.Common | ||||
| { | { | ||||
| public class ModelOptions : IModelParams | |||||
| public class ModelOptions | |||||
| : IModelParams | |||||
| { | { | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| @@ -86,16 +88,6 @@ namespace LLama.Web.Common | |||||
| /// </summary> | /// </summary> | ||||
| public float[] TensorSplits { get; set; } | public float[] TensorSplits { get; set; } | ||||
| /// <summary> | |||||
| /// Grouped-Query Attention | |||||
| /// </summary> | |||||
| public int GroupedQueryAttention { get; set; } = 1; | |||||
| /// <summary> | |||||
| /// RMS Norm Epsilon | |||||
| /// </summary> | |||||
| public float RmsNormEpsilon { get; set; } = 5e-6f; | |||||
| /// <summary> | /// <summary> | ||||
| /// RoPE base frequency | /// RoPE base frequency | ||||
| /// </summary> | /// </summary> | ||||
| @@ -111,5 +103,9 @@ namespace LLama.Web.Common | |||||
| /// </summary> | /// </summary> | ||||
| public bool MulMatQ { get; set; } | public bool MulMatQ { get; set; } | ||||
| } | |||||
| /// <summary> | |||||
| /// The encoding to use for models | |||||
| /// </summary> | |||||
| public Encoding Encoding { get; set; } = Encoding.UTF8; | |||||
| } | |||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using LLama.Native; | |||||
| namespace LLama.Web.Common | namespace LLama.Web.Common | ||||
| { | { | ||||
| @@ -95,5 +96,10 @@ namespace LLama.Web.Common | |||||
| /// consider newlines as a repeatable token (penalize_nl) | /// consider newlines as a repeatable token (penalize_nl) | ||||
| /// </summary> | /// </summary> | ||||
| public bool PenalizeNL { get; set; } = true; | public bool PenalizeNL { get; set; } = true; | ||||
| } | |||||
| /// <summary> | |||||
| /// A grammar to constrain possible tokens | |||||
| /// </summary> | |||||
| public SafeLLamaGrammarHandle Grammar { get; set; } = null; | |||||
| } | |||||
| } | } | ||||
| @@ -60,7 +60,8 @@ namespace LLama.Web.Models | |||||
| { | { | ||||
| _inferenceOptions = null; | _inferenceOptions = null; | ||||
| _outputTransform = null; | _outputTransform = null; | ||||
| _executor.Model?.Dispose(); | |||||
| _executor?.Context.Dispose(); | |||||
| _executor = null; | _executor = null; | ||||
| } | } | ||||
| } | } | ||||
| @@ -51,7 +51,7 @@ namespace LLama.Web.Services | |||||
| return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached")); | return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached")); | ||||
| // Create model | // Create model | ||||
| var llamaModel = new LLamaModel(modelOption); | |||||
| var llamaModel = new LLamaContext(modelOption); | |||||
| // Create executor | // Create executor | ||||
| ILLamaExecutor executor = executorType switch | ILLamaExecutor executor = executorType switch | ||||
| @@ -8,7 +8,7 @@ namespace LLama.WebAPI.Services; | |||||
| public class StatefulChatService : IDisposable | public class StatefulChatService : IDisposable | ||||
| { | { | ||||
| private readonly ChatSession _session; | private readonly ChatSession _session; | ||||
| private readonly LLamaModel _model; | |||||
| private readonly LLamaContext _context; | |||||
| private bool _continue = false; | private bool _continue = false; | ||||
| private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" | private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" | ||||
| @@ -16,13 +16,16 @@ public class StatefulChatService : IDisposable | |||||
| public StatefulChatService(IConfiguration configuration) | public StatefulChatService(IConfiguration configuration) | ||||
| { | { | ||||
| _model = new LLamaModel(new Common.ModelParams(configuration["ModelPath"], contextSize: 512)); | |||||
| _session = new ChatSession(new InteractiveExecutor(_model)); | |||||
| _context = new LLamaContext(new Common.ModelParams(configuration["ModelPath"]) | |||||
| { | |||||
| ContextSize = 512 | |||||
| }); | |||||
| _session = new ChatSession(new InteractiveExecutor(_context)); | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| _model?.Dispose(); | |||||
| _context?.Dispose(); | |||||
| } | } | ||||
| public string Send(SendMessageInput input) | public string Send(SendMessageInput input) | ||||
| @@ -7,14 +7,17 @@ namespace LLama.WebAPI.Services | |||||
| { | { | ||||
| public class StatelessChatService | public class StatelessChatService | ||||
| { | { | ||||
| private readonly LLamaModel _model; | |||||
| private readonly LLamaContext _context; | |||||
| private readonly ChatSession _session; | private readonly ChatSession _session; | ||||
| public StatelessChatService(IConfiguration configuration) | public StatelessChatService(IConfiguration configuration) | ||||
| { | { | ||||
| _model = new LLamaModel(new ModelParams(configuration["ModelPath"], contextSize: 512)); | |||||
| _context = new LLamaContext(new ModelParams(configuration["ModelPath"]) | |||||
| { | |||||
| ContextSize = 512, | |||||
| }); | |||||
| // TODO: replace with a stateless executor | // TODO: replace with a stateless executor | ||||
| _session = new ChatSession(new InteractiveExecutor(_model)) | |||||
| _session = new ChatSession(new InteractiveExecutor(_context)) | |||||
| .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8)) | .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8)) | ||||
| .WithHistoryTransform(new HistoryTransform()); | .WithHistoryTransform(new HistoryTransform()); | ||||
| } | } | ||||
| @@ -1,7 +1,4 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace LLama.Abstractions | namespace LLama.Abstractions | ||||
| { | { | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.Native; | |||||
| namespace LLama.Abstractions | namespace LLama.Abstractions | ||||
| { | { | ||||
| @@ -113,5 +114,10 @@ namespace LLama.Abstractions | |||||
| /// consider newlines as a repeatable token (penalize_nl) | /// consider newlines as a repeatable token (penalize_nl) | ||||
| /// </summary> | /// </summary> | ||||
| public bool PenalizeNL { get; set; } | public bool PenalizeNL { get; set; } | ||||
| /// <summary> | |||||
| /// Grammar to constrain possible tokens | |||||
| /// </summary> | |||||
| SafeLLamaGrammarHandle? Grammar { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,4 @@ | |||||
| using LLama.Common; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using System.Collections.Generic; | |||||
| using System.Threading; | using System.Threading; | ||||
| namespace LLama.Abstractions | namespace LLama.Abstractions | ||||
| @@ -12,9 +9,9 @@ namespace LLama.Abstractions | |||||
| public interface ILLamaExecutor | public interface ILLamaExecutor | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// The loaded model for this executor. | |||||
| /// The loaded context for this executor. | |||||
| /// </summary> | /// </summary> | ||||
| public LLamaModel Model { get; } | |||||
| public LLamaContext Context { get; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Infers a response from the model. | /// Infers a response from the model. | ||||
| @@ -1,7 +1,10 @@ | |||||
| using System; | |||||
| using System.Text; | |||||
| namespace LLama.Abstractions | namespace LLama.Abstractions | ||||
| { | { | ||||
| /// <summary> | |||||
| /// The parameters for initializing a LLama model. | |||||
| /// </summary> | |||||
| public interface IModelParams | public interface IModelParams | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| @@ -95,16 +98,6 @@ namespace LLama.Abstractions | |||||
| /// </summary> | /// </summary> | ||||
| float[]? TensorSplits { get; set; } | float[]? TensorSplits { get; set; } | ||||
| /// <summary> | |||||
| /// Grouped-Query Attention | |||||
| /// </summary> | |||||
| int GroupedQueryAttention { get; set; } | |||||
| /// <summary> | |||||
| /// RMS Norm Epsilon | |||||
| /// </summary> | |||||
| float RmsNormEpsilon { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// RoPE base frequency | /// RoPE base frequency | ||||
| /// </summary> | /// </summary> | ||||
| @@ -119,5 +112,10 @@ namespace LLama.Abstractions | |||||
| /// Use experimental mul_mat_q kernels | /// Use experimental mul_mat_q kernels | ||||
| /// </summary> | /// </summary> | ||||
| bool MulMatQ { get; set; } | bool MulMatQ { get; set; } | ||||
| /// <summary> | |||||
| /// The encoding to use for models | |||||
| /// </summary> | |||||
| Encoding Encoding { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,4 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using System.Collections.Generic; | |||||
| namespace LLama.Abstractions | namespace LLama.Abstractions | ||||
| { | { | ||||
| @@ -15,6 +13,7 @@ namespace LLama.Abstractions | |||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| IEnumerable<string> Transform(IEnumerable<string> tokens); | IEnumerable<string> Transform(IEnumerable<string> tokens); | ||||
| /// <summary> | /// <summary> | ||||
| /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously. | /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -1,8 +1,4 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace LLama.Abstractions | |||||
| namespace LLama.Abstractions | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// An interface for text transformations. | /// An interface for text transformations. | ||||
| @@ -0,0 +1,3 @@ | |||||
| using System.Runtime.CompilerServices; | |||||
| [assembly: InternalsVisibleTo("LLama.Unittest")] | |||||
| @@ -5,6 +5,7 @@ using System.IO; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading; | using System.Threading; | ||||
| using System.Threading.Tasks; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| @@ -13,10 +14,12 @@ namespace LLama | |||||
| /// </summary> | /// </summary> | ||||
| public class ChatSession | public class ChatSession | ||||
| { | { | ||||
| private ILLamaExecutor _executor; | |||||
| private ChatHistory _history; | |||||
| private static readonly string _executorStateFilename = "ExecutorState.json"; | |||||
| private static readonly string _modelStateFilename = "ModelState.st"; | |||||
| private readonly ILLamaExecutor _executor; | |||||
| private readonly ChatHistory _history; | |||||
| private const string _executorStateFilename = "ExecutorState.json"; | |||||
| private const string _modelStateFilename = "ModelState.st"; | |||||
| /// <summary> | /// <summary> | ||||
| /// The executor for this session. | /// The executor for this session. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -91,7 +94,7 @@ namespace LLama | |||||
| { | { | ||||
| Directory.CreateDirectory(path); | Directory.CreateDirectory(path); | ||||
| } | } | ||||
| _executor.Model.SaveState(Path.Combine(path, _modelStateFilename)); | |||||
| _executor.Context.SaveState(Path.Combine(path, _modelStateFilename)); | |||||
| if(Executor is StatelessExecutor) | if(Executor is StatelessExecutor) | ||||
| { | { | ||||
| @@ -116,7 +119,7 @@ namespace LLama | |||||
| { | { | ||||
| throw new FileNotFoundException($"Directory {path} does not exist."); | throw new FileNotFoundException($"Directory {path} does not exist."); | ||||
| } | } | ||||
| _executor.Model.LoadState(Path.Combine(path, _modelStateFilename)); | |||||
| _executor.Context.LoadState(Path.Combine(path, _modelStateFilename)); | |||||
| if (Executor is StatelessExecutor) | if (Executor is StatelessExecutor) | ||||
| { | { | ||||
| @@ -227,7 +230,7 @@ namespace LLama | |||||
| private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | ||||
| { | { | ||||
| var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); | var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); | ||||
| await foreach (var item in OutputTransform.TransformAsync(results)) | |||||
| await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) | |||||
| { | { | ||||
| yield return item; | yield return item; | ||||
| } | } | ||||
| @@ -15,9 +15,20 @@ namespace LLama.Common | |||||
| private readonly int _maxSize; | private readonly int _maxSize; | ||||
| private readonly List<T> _storage; | private readonly List<T> _storage; | ||||
| /// <summary> | |||||
| /// Number of items in this queue | |||||
| /// </summary> | |||||
| public int Count => _storage.Count; | public int Count => _storage.Count; | ||||
| /// <summary> | |||||
| /// Maximum number of items allowed in this queue | |||||
| /// </summary> | |||||
| public int Capacity => _maxSize; | public int Capacity => _maxSize; | ||||
| /// <summary> | |||||
| /// Create a new queue | |||||
| /// </summary> | |||||
| /// <param name="size">the maximum number of items to store in this queue</param> | |||||
| public FixedSizeQueue(int size) | public FixedSizeQueue(int size) | ||||
| { | { | ||||
| _maxSize = size; | _maxSize = size; | ||||
| @@ -1,6 +1,7 @@ | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using LLama.Native; | |||||
| namespace LLama.Common | namespace LLama.Common | ||||
| { | { | ||||
| @@ -96,6 +97,11 @@ namespace LLama.Common | |||||
| /// consider newlines as a repeatable token (penalize_nl) | /// consider newlines as a repeatable token (penalize_nl) | ||||
| /// </summary> | /// </summary> | ||||
| public bool PenalizeNL { get; set; } = true; | public bool PenalizeNL { get; set; } = true; | ||||
| /// <summary> | |||||
| /// A grammar to constrain the possible tokens | |||||
| /// </summary> | |||||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,21 +1,44 @@ | |||||
| using System; | |||||
| using LLama.Native; | |||||
| using System; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.IO; | using System.IO; | ||||
| using static LLama.Common.ILLamaLogger; | using static LLama.Common.ILLamaLogger; | ||||
| namespace LLama.Common; | namespace LLama.Common; | ||||
| /// <summary> | |||||
| /// receives log messages from LLamaSharp | |||||
| /// </summary> | |||||
| public interface ILLamaLogger | public interface ILLamaLogger | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Severity level of a log message | |||||
| /// </summary> | |||||
| public enum LogLevel | public enum LogLevel | ||||
| { | { | ||||
| Info, | |||||
| Debug, | |||||
| Warning, | |||||
| Error | |||||
| /// <summary> | |||||
| /// Logs that are used for interactive investigation during development. | |||||
| /// </summary> | |||||
| Debug = 1, | |||||
| /// <summary> | |||||
| /// Logs that highlight when the current flow of execution is stopped due to a failure. | |||||
| /// </summary> | |||||
| Error = 2, | |||||
| /// <summary> | |||||
| /// Logs that highlight an abnormal or unexpected event in the application flow, but do not otherwise cause the application execution to stop. | |||||
| /// </summary> | |||||
| Warning = 3, | |||||
| /// <summary> | |||||
| /// Logs that track the general flow of the application. | |||||
| /// </summary> | |||||
| Info = 4 | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Write the log in cosutomized way | |||||
| /// Write the log in customized way | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="source">The source of the log. It may be a method name or class name.</param> | /// <param name="source">The source of the log. It may be a method name or class name.</param> | ||||
| /// <param name="message">The message.</param> | /// <param name="message">The message.</param> | ||||
| @@ -24,19 +47,23 @@ public interface ILLamaLogger | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// The default logger of LLamaSharp. On default it write to console. User methods of `LLamaLogger.Default` to change the behavior. | |||||
| /// It's more recommended to inherit `ILLamaLogger` to cosutomize the behavior. | |||||
| /// The default logger of LLamaSharp. On default it write to console. Use methods of `LLamaLogger.Default` to change the behavior. | |||||
| /// It's recommended to inherit `ILLamaLogger` to customize the behavior. | |||||
| /// </summary> | /// </summary> | ||||
| public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| public sealed class LLamaDefaultLogger | |||||
| : ILLamaLogger | |||||
| { | { | ||||
| private static readonly Lazy<LLamaDefaultLogger> _instance = new Lazy<LLamaDefaultLogger>(() => new LLamaDefaultLogger()); | private static readonly Lazy<LLamaDefaultLogger> _instance = new Lazy<LLamaDefaultLogger>(() => new LLamaDefaultLogger()); | ||||
| private bool _toConsole = true; | private bool _toConsole = true; | ||||
| private bool _toFile = false; | |||||
| private bool _toFile; | |||||
| private FileStream? _fileStream = null; | |||||
| private StreamWriter _fileWriter = null; | |||||
| private FileStream? _fileStream; | |||||
| private StreamWriter? _fileWriter; | |||||
| /// <summary> | |||||
| /// Get the default logger instance | |||||
| /// </summary> | |||||
| public static LLamaDefaultLogger Default => _instance.Value; | public static LLamaDefaultLogger Default => _instance.Value; | ||||
| private LLamaDefaultLogger() | private LLamaDefaultLogger() | ||||
| @@ -44,18 +71,42 @@ public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Enable logging output from llama.cpp | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public LLamaDefaultLogger EnableNative() | |||||
| { | |||||
| EnableNativeLogCallback(); | |||||
| return this; | |||||
| } | |||||
| /// <summary> | |||||
| /// Enable writing log messages to console | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public LLamaDefaultLogger EnableConsole() | public LLamaDefaultLogger EnableConsole() | ||||
| { | { | ||||
| _toConsole = true; | _toConsole = true; | ||||
| return this; | return this; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Disable writing messages to console | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public LLamaDefaultLogger DisableConsole() | public LLamaDefaultLogger DisableConsole() | ||||
| { | { | ||||
| _toConsole = false; | _toConsole = false; | ||||
| return this; | return this; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Enable writing log messages to file | |||||
| /// </summary> | |||||
| /// <param name="filename"></param> | |||||
| /// <param name="mode"></param> | |||||
| /// <returns></returns> | |||||
| public LLamaDefaultLogger EnableFile(string filename, FileMode mode = FileMode.Append) | public LLamaDefaultLogger EnableFile(string filename, FileMode mode = FileMode.Append) | ||||
| { | { | ||||
| _fileStream = new FileStream(filename, mode, FileAccess.Write); | _fileStream = new FileStream(filename, mode, FileAccess.Write); | ||||
| @@ -64,7 +115,22 @@ public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| return this; | return this; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Disable writing log messages to file | |||||
| /// </summary> | |||||
| /// <param name="filename">unused!</param> | |||||
| /// <returns></returns> | |||||
| [Obsolete("Use DisableFile method without 'filename' parameter")] | |||||
| public LLamaDefaultLogger DisableFile(string filename) | public LLamaDefaultLogger DisableFile(string filename) | ||||
| { | |||||
| return DisableFile(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Disable writing log messages to file | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public LLamaDefaultLogger DisableFile() | |||||
| { | { | ||||
| if (_fileWriter is not null) | if (_fileWriter is not null) | ||||
| { | { | ||||
| @@ -80,6 +146,12 @@ public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| return this; | return this; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Log a message | |||||
| /// </summary> | |||||
| /// <param name="source">The source of this message (e.g. class name)</param> | |||||
| /// <param name="message">The message to log</param> | |||||
| /// <param name="level">Severity level of this message</param> | |||||
| public void Log(string source, string message, LogLevel level) | public void Log(string source, string message, LogLevel level) | ||||
| { | { | ||||
| if (level == LogLevel.Info) | if (level == LogLevel.Info) | ||||
| @@ -100,6 +172,10 @@ public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Write a log message with "Info" severity | |||||
| /// </summary> | |||||
| /// <param name="message"></param> | |||||
| public void Info(string message) | public void Info(string message) | ||||
| { | { | ||||
| message = MessageFormat("info", message); | message = MessageFormat("info", message); | ||||
| @@ -117,6 +193,10 @@ public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Write a log message with "Warn" severity | |||||
| /// </summary> | |||||
| /// <param name="message"></param> | |||||
| public void Warn(string message) | public void Warn(string message) | ||||
| { | { | ||||
| message = MessageFormat("warn", message); | message = MessageFormat("warn", message); | ||||
| @@ -134,6 +214,10 @@ public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Write a log message with "Error" severity | |||||
| /// </summary> | |||||
| /// <param name="message"></param> | |||||
| public void Error(string message) | public void Error(string message) | ||||
| { | { | ||||
| message = MessageFormat("error", message); | message = MessageFormat("error", message); | ||||
| @@ -151,10 +235,36 @@ public sealed class LLamaDefaultLogger : ILLamaLogger | |||||
| } | } | ||||
| } | } | ||||
| private string MessageFormat(string level, string message) | |||||
| private static string MessageFormat(string level, string message) | |||||
| { | { | ||||
| DateTime now = DateTime.Now; | |||||
| string formattedDate = now.ToString("yyyy.MM.dd HH:mm:ss"); | |||||
| return $"[{formattedDate}][{level}]: {message}"; | |||||
| var now = DateTime.Now; | |||||
| return $"[{now:yyyy.MM.dd HH:mm:ss}][{level}]: {message}"; | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Register native logging callback | |||||
| /// </summary> | |||||
| private void EnableNativeLogCallback() | |||||
| { | |||||
| // TODO: Move to a more appropriate place once we have a intitialize method | |||||
| NativeApi.llama_log_set(NativeLogCallback); | |||||
| } | |||||
| /// <summary> | |||||
| /// Callback for native logging function | |||||
| /// </summary> | |||||
| /// <param name="level">The log level</param> | |||||
| /// <param name="message">The log message</param> | |||||
| private void NativeLogCallback(LogLevel level, string message) | |||||
| { | |||||
| if (string.IsNullOrEmpty(message)) | |||||
| return; | |||||
| // Note that text includes the new line character at the end for most events. | |||||
| // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it | |||||
| // if it exists. | |||||
| // It might not exist for progress report where '.' is output repeatedly. | |||||
| Log(default!, message.TrimEnd('\n'), level); | |||||
| } | |||||
| } | } | ||||
| @@ -1,12 +1,15 @@ | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using System; | using System; | ||||
| using System.Text; | |||||
| using System.Text.Json; | |||||
| using System.Text.Json.Serialization; | |||||
| namespace LLama.Common | namespace LLama.Common | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// The parameters for initializing a LLama model. | /// The parameters for initializing a LLama model. | ||||
| /// </summary> | /// </summary> | ||||
| public class ModelParams | |||||
| public record ModelParams | |||||
| : IModelParams | : IModelParams | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| @@ -86,16 +89,6 @@ namespace LLama.Common | |||||
| /// </summary> | /// </summary> | ||||
| public float[]? TensorSplits { get; set; } | public float[]? TensorSplits { get; set; } | ||||
| /// <summary> | |||||
| /// Grouped-Query Attention | |||||
| /// </summary> | |||||
| public int GroupedQueryAttention { get; set; } = 1; | |||||
| /// <summary> | |||||
| /// RMS Norm Epsilon | |||||
| /// </summary> | |||||
| public float RmsNormEpsilon { get; set; } = 5e-6f; | |||||
| /// <summary> | /// <summary> | ||||
| /// RoPE base frequency | /// RoPE base frequency | ||||
| /// </summary> | /// </summary> | ||||
| @@ -111,34 +104,57 @@ namespace LLama.Common | |||||
| /// </summary> | /// </summary> | ||||
| public bool MulMatQ { get; set; } | public bool MulMatQ { get; set; } | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="modelPath">The model path.</param> | |||||
| /// <param name="contextSize">Model context size (n_ctx)</param> | |||||
| /// <param name="gpuLayerCount">Number of layers to run in VRAM / GPU memory (n_gpu_layers)</param> | |||||
| /// <param name="seed">Seed for the random number generator (seed)</param> | |||||
| /// <param name="useFp16Memory">Whether to use f16 instead of f32 for memory kv (memory_f16)</param> | |||||
| /// <param name="useMemorymap">Whether to use mmap for faster loads (use_mmap)</param> | |||||
| /// <param name="useMemoryLock">Whether to use mlock to keep model in memory (use_mlock)</param> | |||||
| /// <param name="perplexity">Thether to compute perplexity over the prompt (perplexity)</param> | |||||
| /// <param name="loraAdapter">Lora adapter path (lora_adapter)</param> | |||||
| /// <param name="loraBase">Base model path for the lora adapter (lora_base)</param> | |||||
| /// <param name="threads">Number of threads (-1 = autodetect) (n_threads)</param> | |||||
| /// <param name="batchSize">Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)</param> | |||||
| /// <param name="convertEosToNewLine">Whether to convert eos to newline during the inference.</param> | |||||
| /// <param name="embeddingMode">Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore.</param> | |||||
| /// <param name="gqa">Grouped-Query Attention</param> | |||||
| /// <param name="rmsNormEps">RMS Norm Epsilon</param> | |||||
| /// <param name="rope_freq_base">RoPE base frequency.</param> | |||||
| /// <param name="rope_freq_scale">RoPE frequency scaling factor</param> | |||||
| /// <param name="muMatQ">Use experimental mul_mat_q kernels</param> | |||||
| public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, | |||||
| int seed = 1337, bool useFp16Memory = true, | |||||
| bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, | |||||
| string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, | |||||
| bool convertEosToNewLine = false, bool embeddingMode = false, | |||||
| int gqa = 1, float rmsNormEps = 5e-6f, float rope_freq_base = 10000.0f, float rope_freq_scale = 1f, bool muMatQ = false) | |||||
| /// <summary> | |||||
| /// The encoding to use to convert text for the model | |||||
| /// </summary> | |||||
| [JsonConverter(typeof(EncodingConverter))] | |||||
| public Encoding Encoding { get; set; } = Encoding.UTF8; | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="modelPath">The model path.</param> | |||||
| [JsonConstructor] | |||||
| public ModelParams(string modelPath) | |||||
| { | |||||
| ModelPath = modelPath; | |||||
| } | |||||
| private ModelParams() | |||||
| { | |||||
| // This constructor (default parameterless constructor) is used by Newtonsoft to deserialize! | |||||
| ModelPath = ""; | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="modelPath">The model path.</param> | |||||
| /// <param name="contextSize">Model context size (n_ctx)</param> | |||||
| /// <param name="gpuLayerCount">Number of layers to run in VRAM / GPU memory (n_gpu_layers)</param> | |||||
| /// <param name="seed">Seed for the random number generator (seed)</param> | |||||
| /// <param name="useFp16Memory">Whether to use f16 instead of f32 for memory kv (memory_f16)</param> | |||||
| /// <param name="useMemorymap">Whether to use mmap for faster loads (use_mmap)</param> | |||||
| /// <param name="useMemoryLock">Whether to use mlock to keep model in memory (use_mlock)</param> | |||||
| /// <param name="perplexity">Thether to compute perplexity over the prompt (perplexity)</param> | |||||
| /// <param name="loraAdapter">Lora adapter path (lora_adapter)</param> | |||||
| /// <param name="loraBase">Base model path for the lora adapter (lora_base)</param> | |||||
| /// <param name="threads">Number of threads (-1 = autodetect) (n_threads)</param> | |||||
| /// <param name="batchSize">Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)</param> | |||||
| /// <param name="convertEosToNewLine">Whether to convert eos to newline during the inference.</param> | |||||
| /// <param name="embeddingMode">Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore.</param> | |||||
| /// <param name="ropeFrequencyBase">RoPE base frequency.</param> | |||||
| /// <param name="ropeFrequencyScale">RoPE frequency scaling factor</param> | |||||
| /// <param name="mulMatQ">Use experimental mul_mat_q kernels</param> | |||||
| /// <param name="encoding">The encoding to use to convert text for the model</param> | |||||
| [Obsolete("Use object initializer to set all optional parameters")] | |||||
| public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, | |||||
| int seed = 1337, bool useFp16Memory = true, | |||||
| bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, | |||||
| string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, | |||||
| bool convertEosToNewLine = false, bool embeddingMode = false, | |||||
| float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false, | |||||
| string encoding = "UTF-8") | |||||
| { | { | ||||
| ContextSize = contextSize; | ContextSize = contextSize; | ||||
| GpuLayerCount = gpuLayerCount; | GpuLayerCount = gpuLayerCount; | ||||
| @@ -154,11 +170,27 @@ namespace LLama.Common | |||||
| BatchSize = batchSize; | BatchSize = batchSize; | ||||
| ConvertEosToNewLine = convertEosToNewLine; | ConvertEosToNewLine = convertEosToNewLine; | ||||
| EmbeddingMode = embeddingMode; | EmbeddingMode = embeddingMode; | ||||
| GroupedQueryAttention = gqa; | |||||
| RmsNormEpsilon = rmsNormEps; | |||||
| RopeFrequencyBase = rope_freq_base; | |||||
| RopeFrequencyScale = rope_freq_scale; | |||||
| MulMatQ = muMatQ; | |||||
| } | |||||
| RopeFrequencyBase = ropeFrequencyBase; | |||||
| RopeFrequencyScale = ropeFrequencyScale; | |||||
| MulMatQ = mulMatQ; | |||||
| Encoding = Encoding.GetEncoding(encoding); | |||||
| } | |||||
| } | |||||
| internal class EncodingConverter | |||||
| : JsonConverter<Encoding> | |||||
| { | |||||
| public override Encoding? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) | |||||
| { | |||||
| var name = reader.GetString(); | |||||
| if (name == null) | |||||
| return null; | |||||
| return Encoding.GetEncoding(name); | |||||
| } | |||||
| public override void Write(Utf8JsonWriter writer, Encoding value, JsonSerializerOptions options) | |||||
| { | |||||
| writer.WriteStringValue(value.WebName); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,125 @@ | |||||
| using System; | |||||
| namespace LLama.Exceptions; | |||||
| /// <summary> | |||||
| /// Base class for all grammar exceptions | |||||
| /// </summary> | |||||
| public abstract class GrammarFormatException | |||||
| : Exception | |||||
| { | |||||
| internal GrammarFormatException(string message) | |||||
| : base(message) | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// An incorrect number of characters were encountered while parsing a hex literal | |||||
| /// </summary> | |||||
| public class GrammarUnexpectedHexCharsCount | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarUnexpectedHexCharsCount(int size, string source) | |||||
| : base($"Expecting {size} hex chars at {source}") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Failed to parse a "name" element when one was expected | |||||
| /// </summary> | |||||
| public class GrammarExpectedName | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarExpectedName(string source) | |||||
| : base($"Expecting name at {source}") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// An unexpected character was encountered after an escape sequence | |||||
| /// </summary> | |||||
| public class GrammarUnknownEscapeCharacter | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarUnknownEscapeCharacter(string source) | |||||
| : base($"Unknown escape at {source}") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// End-of-file was encountered while parsing | |||||
| /// </summary> | |||||
| public class GrammarUnexpectedEndOfInput | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarUnexpectedEndOfInput() | |||||
| : base($"Unexpected end of input") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// A specified string was expected when parsing | |||||
| /// </summary> | |||||
| public class GrammarExpectedNext | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarExpectedNext(string expected, string source) | |||||
| : base($"Expected '{expected}' at {source}") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// A specified character was expected to preceded another when parsing | |||||
| /// </summary> | |||||
| public class GrammarExpectedPrevious | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarExpectedPrevious(string expected, string source) | |||||
| : base($"Expecting preceding item to be '{expected}' at {source}") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// A CHAR_ALT was created without a preceding CHAR element | |||||
| /// </summary> | |||||
| public class GrammarUnexpectedCharAltElement | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarUnexpectedCharAltElement(string ruleId, int index) | |||||
| : base($"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{index}") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// A CHAR_RNG was created without a preceding CHAR element | |||||
| /// </summary> | |||||
| public class GrammarUnexpectedCharRngElement | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarUnexpectedCharRngElement(string ruleId, int index) | |||||
| : base($"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{index}") | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// An END was encountered before the last element | |||||
| /// </summary> | |||||
| public class GrammarUnexpectedEndElement | |||||
| : GrammarFormatException | |||||
| { | |||||
| internal GrammarUnexpectedEndElement(string ruleId, int index) | |||||
| : base($"Unexpected LLamaGrammarElementType.END: {ruleId},{index}") | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -2,14 +2,16 @@ | |||||
| namespace LLama.Exceptions | namespace LLama.Exceptions | ||||
| { | { | ||||
| public class RuntimeError: Exception | |||||
| public class RuntimeError | |||||
| : Exception | |||||
| { | { | ||||
| public RuntimeError() | public RuntimeError() | ||||
| { | { | ||||
| } | } | ||||
| public RuntimeError(string message): base(message) | |||||
| public RuntimeError(string message) | |||||
| : base(message) | |||||
| { | { | ||||
| } | } | ||||
| @@ -0,0 +1,14 @@ | |||||
| using System.Collections.Generic; | |||||
| namespace LLama.Extensions | |||||
| { | |||||
| internal static class DictionaryExtensions | |||||
| { | |||||
| #if NETSTANDARD2_0 | |||||
| public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue) | |||||
| { | |||||
| return dictionary.TryGetValue(key, out var value) ? value : defaultValue; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,21 @@ | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| namespace LLama.Extensions | |||||
| { | |||||
| internal static class IEnumerableExtensions | |||||
| { | |||||
| #if NETSTANDARD2_0 | |||||
| public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> source, int count) | |||||
| { | |||||
| var list = source.ToList(); | |||||
| if (count >= list.Count) | |||||
| return list; | |||||
| list.RemoveRange(0, list.Count - count); | |||||
| return list; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| } | |||||
| @@ -39,8 +39,6 @@ namespace LLama.Extensions | |||||
| result.logits_all = @params.Perplexity; | result.logits_all = @params.Perplexity; | ||||
| result.embedding = @params.EmbeddingMode; | result.embedding = @params.EmbeddingMode; | ||||
| result.low_vram = @params.LowVram; | result.low_vram = @params.LowVram; | ||||
| result.n_gqa = @params.GroupedQueryAttention; | |||||
| result.rms_norm_eps = @params.RmsNormEpsilon; | |||||
| result.rope_freq_base = @params.RopeFrequencyBase; | result.rope_freq_base = @params.RopeFrequencyBase; | ||||
| result.rope_freq_scale = @params.RopeFrequencyScale; | result.rope_freq_scale = @params.RopeFrequencyScale; | ||||
| result.mul_mat_q = @params.MulMatQ; | result.mul_mat_q = @params.MulMatQ; | ||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| namespace LLama.Extensions | |||||
| { | |||||
| internal static class IReadOnlyListExtensions | |||||
| { | |||||
| public static int? IndexOf<T>(this IReadOnlyList<T> list, T item) | |||||
| where T : IEquatable<T> | |||||
| { | |||||
| for (var i = 0; i < list.Count; i++) | |||||
| { | |||||
| if (list[i].Equals(item)) | |||||
| return i; | |||||
| } | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| namespace LLama.Extensions | |||||
| { | |||||
| internal static class ListExtensions | |||||
| { | |||||
| public static void AddRangeSpan<T>(this List<T> list, ReadOnlySpan<T> span) | |||||
| { | |||||
| for (var i = 0; i < span.Length; i++) | |||||
| list.Add(span[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| // This file is used by Code Analysis to maintain SuppressMessage | |||||
| // attributes that are applied to this project. | |||||
| // Project-level suppressions either have no target or are given | |||||
| // a specific target and scoped to a namespace, type, member, etc. | |||||
| using System.Diagnostics.CodeAnalysis; | |||||
| [assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] | |||||
| [assembly: SuppressMessage("Style", "IDE0070:Use 'System.HashCode'", Justification = "Not compatible with netstandard2.0")] | |||||
| @@ -0,0 +1,406 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | |||||
| namespace LLama.Grammars | |||||
| { | |||||
| /// <summary> | |||||
| /// Source: | |||||
| /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp | |||||
| /// | |||||
| /// The commit hash from URL is the actual commit hash that reflects current C# code. | |||||
| /// </summary> | |||||
| internal sealed class GBNFGrammarParser | |||||
| { | |||||
| // NOTE: assumes valid utf8 (but checks for overrun) | |||||
| // copied from llama.cpp | |||||
| private uint DecodeUTF8(ref ReadOnlySpan<byte> src) | |||||
| { | |||||
| int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; | |||||
| byte firstByte = src[0]; | |||||
| byte highbits = (byte)(firstByte >> 4); | |||||
| int len = lookup[highbits]; | |||||
| byte mask = (byte)((1 << (8 - len)) - 1); | |||||
| uint value = (uint)(firstByte & mask); | |||||
| int end = len; | |||||
| int pos = 1; | |||||
| for (; pos < end && pos < src.Length; pos++) | |||||
| { | |||||
| value = (uint)((value << 6) + (src[pos] & 0x3F)); | |||||
| } | |||||
| src = src.Slice(pos); | |||||
| return value; | |||||
| } | |||||
| private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len) | |||||
| { | |||||
| uint nextId = (uint)state.SymbolIds.Count; | |||||
| string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray()); | |||||
| if (state.SymbolIds.TryGetValue(key, out uint existingId)) | |||||
| { | |||||
| return existingId; | |||||
| } | |||||
| else | |||||
| { | |||||
| state.SymbolIds[key] = nextId; | |||||
| return nextId; | |||||
| } | |||||
| } | |||||
| private uint GenerateSymbolId(ParseState state, string baseName) | |||||
| { | |||||
| uint nextId = (uint)state.SymbolIds.Count; | |||||
| string key = $"{baseName}_{nextId}"; | |||||
| state.SymbolIds[key] = nextId; | |||||
| return nextId; | |||||
| } | |||||
| private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule) | |||||
| { | |||||
| while (state.Rules.Count <= ruleId) | |||||
| { | |||||
| state.Rules.Add(new List<LLamaGrammarElement>()); | |||||
| } | |||||
| state.Rules[(int)ruleId] = rule; | |||||
| } | |||||
| private bool IsWordChar(byte c) | |||||
| { | |||||
| return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); | |||||
| } | |||||
| private uint ParseHex(ref ReadOnlySpan<byte> src, int size) | |||||
| { | |||||
| int pos = 0; | |||||
| int end = size; | |||||
| uint value = 0; | |||||
| for (; pos < end && pos < src.Length; pos++) | |||||
| { | |||||
| value <<= 4; | |||||
| byte c = src[pos]; | |||||
| if ('a' <= c && c <= 'f') | |||||
| { | |||||
| value += (uint)(c - 'a' + 10); | |||||
| } | |||||
| else if ('A' <= c && c <= 'F') | |||||
| { | |||||
| value += (uint)(c - 'A' + 10); | |||||
| } | |||||
| else if ('0' <= c && c <= '9') | |||||
| { | |||||
| value += (uint)(c - '0'); | |||||
| } | |||||
| else | |||||
| { | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (pos != end) | |||||
| { | |||||
| throw new GrammarUnexpectedHexCharsCount(size, Encoding.UTF8.GetString(src.ToArray())); | |||||
| } | |||||
| src = src.Slice(pos); | |||||
| return value; | |||||
| } | |||||
| private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk) | |||||
| { | |||||
| int pos = 0; | |||||
| while (pos < src.Length && | |||||
| (src[pos] == ' ' || src[pos] == '\t' || src[pos] == '#' || | |||||
| (newlineOk && (src[pos] == '\r' || src[pos] == '\n')))) | |||||
| { | |||||
| if (src[pos] == '#') | |||||
| { | |||||
| while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n') | |||||
| { | |||||
| pos++; | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| pos++; | |||||
| } | |||||
| } | |||||
| return src.Slice(pos); | |||||
| } | |||||
| private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src) | |||||
| { | |||||
| int pos = 0; | |||||
| while (pos < src.Length && IsWordChar(src[pos])) | |||||
| { | |||||
| pos++; | |||||
| } | |||||
| if (pos == 0) | |||||
| { | |||||
| throw new GrammarExpectedName(Encoding.UTF8.GetString(src.ToArray())); | |||||
| } | |||||
| return src.Slice(pos); | |||||
| } | |||||
| private uint ParseChar(ref ReadOnlySpan<byte> src) | |||||
| { | |||||
| if (src[0] == '\\') | |||||
| { | |||||
| var chr = src[1]; | |||||
| src = src.Slice(2); | |||||
| switch (chr) | |||||
| { | |||||
| case (byte)'x': | |||||
| return ParseHex(ref src, 2); | |||||
| case (byte)'u': | |||||
| return ParseHex(ref src, 4); | |||||
| case (byte)'U': | |||||
| return ParseHex(ref src, 8); | |||||
| case (byte)'t': | |||||
| return '\t'; | |||||
| case (byte)'r': | |||||
| return '\r'; | |||||
| case (byte)'n': | |||||
| return '\n'; | |||||
| case (byte)'\\': | |||||
| case (byte)'"': | |||||
| case (byte)'[': | |||||
| case (byte)']': | |||||
| return chr; | |||||
| default: | |||||
| throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())); | |||||
| } | |||||
| } | |||||
| else if (!src.IsEmpty) | |||||
| { | |||||
| return DecodeUTF8(ref src); | |||||
| } | |||||
| throw new GrammarUnexpectedEndOfInput(); | |||||
| } | |||||
| private ReadOnlySpan<byte> ParseSequence( | |||||
| ParseState state, | |||||
| ReadOnlySpan<byte> pos, | |||||
| string ruleName, | |||||
| List<LLamaGrammarElement> outElements, | |||||
| bool isNested) | |||||
| { | |||||
| int lastSymStart = outElements.Count; | |||||
| while (!pos.IsEmpty) | |||||
| { | |||||
| if (pos[0] == '"') // literal string | |||||
| { | |||||
| pos = pos.Slice(1); | |||||
| lastSymStart = outElements.Count; | |||||
| while (!pos.IsEmpty && pos[0] != '"') | |||||
| { | |||||
| var charPair = ParseChar(ref pos); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair)); | |||||
| } | |||||
| pos = ParseSpace(pos.Slice(1), isNested); | |||||
| } | |||||
| else if (pos[0] == '[') // char range(s) | |||||
| { | |||||
| pos = pos.Slice(1); | |||||
| var startType = LLamaGrammarElementType.CHAR; | |||||
| if (pos[0] == '^') | |||||
| { | |||||
| pos = pos.Slice(1); | |||||
| startType = LLamaGrammarElementType.CHAR_NOT; | |||||
| } | |||||
| lastSymStart = outElements.Count; | |||||
| while (!pos.IsEmpty && pos[0] != ']') | |||||
| { | |||||
| var charPair = ParseChar(ref pos); | |||||
| var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; | |||||
| outElements.Add(new LLamaGrammarElement(type, charPair)); | |||||
| if (pos[0] == '-' && pos[1] != ']') | |||||
| { | |||||
| pos = pos.Slice(1); | |||||
| var endCharPair = ParseChar(ref pos); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair)); | |||||
| } | |||||
| } | |||||
| pos = ParseSpace(pos.Slice(1), isNested); | |||||
| } | |||||
| else if (IsWordChar(pos[0])) // rule reference | |||||
| { | |||||
| var nameEnd = ParseName(pos); | |||||
| uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); | |||||
| pos = ParseSpace(nameEnd, isNested); | |||||
| lastSymStart = outElements.Count; | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId)); | |||||
| } | |||||
| else if (pos[0] == '(') // grouping | |||||
| { | |||||
| // parse nested alternates into synthesized rule | |||||
| pos = ParseSpace(pos.Slice(1), true); | |||||
| uint subRuleId = GenerateSymbolId(state, ruleName); | |||||
| pos = ParseAlternates(state, pos, ruleName, subRuleId, true); | |||||
| lastSymStart = outElements.Count; | |||||
| // output reference to synthesized rule | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||||
| if (pos[0] != ')') | |||||
| throw new GrammarExpectedNext(")", Encoding.UTF8.GetString(pos.ToArray())); | |||||
| pos = ParseSpace(pos.Slice(1), isNested); | |||||
| } | |||||
| else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator | |||||
| { | |||||
| if (lastSymStart == outElements.Count) | |||||
| throw new GrammarExpectedPrevious("*/+/?", Encoding.UTF8.GetString(pos.ToArray())); | |||||
| // apply transformation to previous symbol (lastSymStart to end) according to | |||||
| // rewrite rules: | |||||
| // S* --> S' ::= S S' | | |||||
| // S+ --> S' ::= S S' | S | |||||
| // S? --> S' ::= S | | |||||
| uint subRuleId = GenerateSymbolId(state, ruleName); | |||||
| List<LLamaGrammarElement> subRule = new List<LLamaGrammarElement>(); | |||||
| // add preceding symbol to generated rule | |||||
| subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); | |||||
| if (pos[0] == '*' || pos[0] == '+') | |||||
| { | |||||
| // cause generated rule to recurse | |||||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||||
| } | |||||
| // mark start of alternate def | |||||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); | |||||
| if (pos[0] == '+') | |||||
| { | |||||
| // add preceding symbol as alternate only for '+' (otherwise empty) | |||||
| subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); | |||||
| } | |||||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); | |||||
| AddRule(state, subRuleId, subRule); | |||||
| // in original rule, replace previous symbol with reference to generated rule | |||||
| outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||||
| pos = ParseSpace(pos.Slice(1), isNested); | |||||
| } | |||||
| else | |||||
| { | |||||
| break; | |||||
| } | |||||
| } | |||||
| return pos; | |||||
| } | |||||
| private ReadOnlySpan<byte> ParseAlternates( | |||||
| ParseState state, | |||||
| ReadOnlySpan<byte> src, | |||||
| string ruleName, | |||||
| uint ruleId, | |||||
| bool isNested) | |||||
| { | |||||
| var rule = new List<LLamaGrammarElement>(); | |||||
| ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested); | |||||
| while (!pos.IsEmpty && pos[0] == '|') | |||||
| { | |||||
| rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); | |||||
| pos = ParseSpace(pos.Slice(1), true); | |||||
| pos = ParseSequence(state, pos, ruleName, rule, isNested); | |||||
| } | |||||
| rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); | |||||
| AddRule(state, ruleId, rule); | |||||
| return pos; | |||||
| } | |||||
| private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src) | |||||
| { | |||||
| ReadOnlySpan<byte> nameEnd = ParseName(src); | |||||
| ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false); | |||||
| int nameLen = src.Length - nameEnd.Length; | |||||
| uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0); | |||||
| string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray()); | |||||
| if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) | |||||
| throw new GrammarExpectedNext("::=", Encoding.UTF8.GetString(pos.ToArray())); | |||||
| pos = ParseSpace(pos.Slice(3), true); | |||||
| pos = ParseAlternates(state, pos, name, ruleId, false); | |||||
| if (!pos.IsEmpty && pos[0] == '\r') | |||||
| { | |||||
| pos = pos.Slice(pos[1] == '\n' ? 2 : 1); | |||||
| } | |||||
| else if (!pos.IsEmpty && pos[0] == '\n') | |||||
| { | |||||
| pos = pos.Slice(1); | |||||
| } | |||||
| else if (!pos.IsEmpty) | |||||
| { | |||||
| throw new GrammarExpectedNext("newline or EOF", Encoding.UTF8.GetString(pos.ToArray())); | |||||
| } | |||||
| return ParseSpace(pos, true); | |||||
| } | |||||
| /// <summary> | |||||
| /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> | |||||
| /// </summary> | |||||
| /// <param name="input">The string to parse</param> | |||||
| /// <param name="startRule">The name of the root rule of this grammar</param> | |||||
| /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception> | |||||
| /// <returns>A ParseState that can be converted into a grammar for sampling</returns> | |||||
| public Grammar Parse(string input, string startRule) | |||||
| { | |||||
| var byteArray = Encoding.UTF8.GetBytes(input); | |||||
| var state = new ParseState(); | |||||
| var pos = ParseSpace(byteArray, true); | |||||
| while (!pos.IsEmpty) | |||||
| { | |||||
| pos = ParseRule(state, pos); | |||||
| } | |||||
| var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key); | |||||
| var rules = new List<GrammarRule>(); | |||||
| for (var i = 0; i < state.Rules.Count; i++) | |||||
| { | |||||
| var elements = state.Rules[i]; | |||||
| var name = names[(uint)i]; | |||||
| rules.Add(new GrammarRule(name, elements)); | |||||
| } | |||||
| var startRuleIndex = state.SymbolIds[startRule]; | |||||
| return new Grammar(rules, startRuleIndex); | |||||
| } | |||||
| private record ParseState | |||||
| { | |||||
| public SortedDictionary<string, uint> SymbolIds { get; } = new(); | |||||
| public List<List<LLamaGrammarElement>> Rules { get; } = new(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,148 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | |||||
| namespace LLama.Grammars | |||||
| { | |||||
| /// <summary> | |||||
| /// A grammar is a set of <see cref="GrammarRule"/>s for deciding which characters are valid next. Can be used to constrain | |||||
| /// output to certain formats - e.g. force the model to output JSON | |||||
| /// </summary> | |||||
| public sealed class Grammar | |||||
| { | |||||
| /// <summary> | |||||
| /// Index of the initial rule to start from | |||||
| /// </summary> | |||||
| public ulong StartRuleIndex { get; set; } | |||||
| /// <summary> | |||||
| /// The rules which make up this grammar | |||||
| /// </summary> | |||||
| public IReadOnlyList<GrammarRule> Rules { get; } | |||||
| /// <summary> | |||||
| /// Create a new grammar from a set of rules | |||||
| /// </summary> | |||||
| /// <param name="rules">The rules which make up this grammar</param> | |||||
| /// <param name="startRuleIndex">Index of the initial rule to start from</param> | |||||
| /// <exception cref="ArgumentOutOfRangeException"></exception> | |||||
| public Grammar(IReadOnlyList<GrammarRule> rules, ulong startRuleIndex) | |||||
| { | |||||
| if (startRuleIndex >= (uint)rules.Count) | |||||
| throw new ArgumentOutOfRangeException(nameof(startRuleIndex), "startRule must be less than the number of rules"); | |||||
| StartRuleIndex = startRuleIndex; | |||||
| Rules = rules; | |||||
| } | |||||
| /// <summary> | |||||
| /// Create a `SafeLLamaGrammarHandle` instance to use for parsing | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public SafeLLamaGrammarHandle CreateInstance() | |||||
| { | |||||
| return SafeLLamaGrammarHandle.Create(Rules, StartRuleIndex); | |||||
| } | |||||
| /// <summary> | |||||
| /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> into a Grammar | |||||
| /// </summary> | |||||
| /// <param name="gbnf">The string to parse</param> | |||||
| /// <param name="startRule">Name of the start rule of this grammar</param> | |||||
| /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception> | |||||
| /// <returns>A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling</returns> | |||||
| public static Grammar Parse(string gbnf, string startRule) | |||||
| { | |||||
| var parser = new GBNFGrammarParser(); | |||||
| return parser.Parse(gbnf, startRule); | |||||
| } | |||||
| /// <inheritdoc /> | |||||
| public override string ToString() | |||||
| { | |||||
| var builder = new StringBuilder(); | |||||
| PrintGrammar(builder); | |||||
| return builder.ToString(); | |||||
| } | |||||
| private void PrintGrammar(StringBuilder output) | |||||
| { | |||||
| for (var i = 0; i < Rules.Count; i++) | |||||
| PrintRule(output, Rules[i]); | |||||
| } | |||||
| private void PrintRule(StringBuilder output, GrammarRule rule) | |||||
| { | |||||
| output.Append($"{rule.Name} ::= "); | |||||
| for (int i = 0, end = rule.Elements.Count - 1; i < end; i++) | |||||
| { | |||||
| var elem = rule.Elements[i]; | |||||
| switch (elem.Type) | |||||
| { | |||||
| // GrammarRule has already verified that END is not being misused, no need to check again | |||||
| case LLamaGrammarElementType.END: | |||||
| break; | |||||
| case LLamaGrammarElementType.ALT: | |||||
| output.Append("| "); | |||||
| break; | |||||
| case LLamaGrammarElementType.RULE_REF: | |||||
| output.Append($"{Rules[(int)elem.Value].Name} "); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR: | |||||
| output.Append('['); | |||||
| PrintGrammarChar(output, elem.Value); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR_NOT: | |||||
| output.Append("[^"); | |||||
| PrintGrammarChar(output, elem.Value); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| output.Append('-'); | |||||
| PrintGrammarChar(output, elem.Value); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| PrintGrammarChar(output, elem.Value); | |||||
| break; | |||||
| } | |||||
| if (elem.IsCharElement()) | |||||
| { | |||||
| switch (rule.Elements[i + 1].Type) | |||||
| { | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| break; | |||||
| default: | |||||
| output.Append("] "); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| output.AppendLine(); | |||||
| } | |||||
| private static void PrintGrammarChar(StringBuilder output, uint c) | |||||
| { | |||||
| if (c >= 0x20 && c <= 0x7F) | |||||
| { | |||||
| output.Append((char)c); | |||||
| } | |||||
| else | |||||
| { | |||||
| // cop out of encoding UTF-8 | |||||
| output.Append($"<U+{c:X4}>"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,74 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | |||||
| namespace LLama.Grammars | |||||
| { | |||||
| /// <summary> | |||||
| /// A single rule in a <see cref="Grammar"/> | |||||
| /// </summary> | |||||
| public sealed record GrammarRule | |||||
| { | |||||
| /// <summary> | |||||
| /// Name of this rule | |||||
| /// </summary> | |||||
| public string Name { get; } | |||||
| /// <summary> | |||||
| /// The elements of this grammar rule | |||||
| /// </summary> | |||||
| public IReadOnlyList<LLamaGrammarElement> Elements { get; } | |||||
| /// <summary> | |||||
| /// Create a new GrammarRule containing the given elements | |||||
| /// </summary> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="elements"></param> | |||||
| /// <exception cref="ArgumentException"></exception> | |||||
| public GrammarRule(string name, IReadOnlyList<LLamaGrammarElement> elements) | |||||
| { | |||||
| Validate(elements, name); | |||||
| Name = name; | |||||
| Elements = elements; | |||||
| } | |||||
| private static void Validate(IReadOnlyList<LLamaGrammarElement> elements, string name) | |||||
| { | |||||
| if (elements.Count == 0) | |||||
| throw new ArgumentException("Cannot create a GrammarRule with zero elements", nameof(elements)); | |||||
| if (elements[elements.Count - 1].Type != LLamaGrammarElementType.END) | |||||
| throw new ArgumentException("Last grammar element must be END", nameof(elements)); | |||||
| for (var i = 0; i < elements.Count; i++) | |||||
| { | |||||
| switch (elements[i].Type) | |||||
| { | |||||
| case LLamaGrammarElementType.END: | |||||
| if (i != elements.Count - 1) | |||||
| throw new GrammarUnexpectedEndElement(name, i); | |||||
| continue; | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| if (i == 0 || !elements[i - 1].IsCharElement()) | |||||
| throw new GrammarUnexpectedCharRngElement(name, i); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| if (i == 0 || !elements[i - 1].IsCharElement()) | |||||
| throw new GrammarUnexpectedCharAltElement(name, i); | |||||
| break; | |||||
| case LLamaGrammarElementType.ALT: | |||||
| case LLamaGrammarElementType.RULE_REF: | |||||
| case LLamaGrammarElementType.CHAR: | |||||
| case LLamaGrammarElementType.CHAR_NOT: | |||||
| break; | |||||
| default: | |||||
| throw new ArgumentException($"Unknown grammar element type: '{elements[i].Type}'", nameof(elements)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -9,34 +9,48 @@ using System.IO.MemoryMappedFiles; | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using LLama.Extensions; | using LLama.Extensions; | ||||
| using Microsoft.Win32.SafeHandles; | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| /// <summary> | /// <summary> | ||||
| /// The abstraction of a LLama model, which holds the context in the native library. | |||||
| /// A llama_context, which holds all the context required to interact with a model | |||||
| /// </summary> | /// </summary> | ||||
| public class LLamaModel: IDisposable | |||||
| public sealed class LLamaContext | |||||
| : IDisposable | |||||
| { | { | ||||
| // TODO: expose more properties. | |||||
| ILLamaLogger? _logger; | |||||
| Encoding _encoding; | |||||
| SafeLLamaContextHandle _ctx; | |||||
| private readonly ILLamaLogger? _logger; | |||||
| private readonly Encoding _encoding; | |||||
| private readonly SafeLLamaContextHandle _ctx; | |||||
| /// <summary> | |||||
| /// Total number of tokens in vocabulary of this model | |||||
| /// </summary> | |||||
| public int VocabCount => _ctx.VocabCount; | |||||
| /// <summary> | |||||
| /// Total number of tokens in the context | |||||
| /// </summary> | |||||
| public int ContextSize => _ctx.ContextSize; | |||||
| /// <summary> | /// <summary> | ||||
| /// The context size. | |||||
| /// Dimension of embedding vectors | |||||
| /// </summary> | /// </summary> | ||||
| public int ContextSize { get; } | |||||
| public int EmbeddingSize => _ctx.EmbeddingSize; | |||||
| /// <summary> | /// <summary> | ||||
| /// The model params set for this model. | /// The model params set for this model. | ||||
| /// </summary> | /// </summary> | ||||
| public IModelParams Params { get; set; } | public IModelParams Params { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// The native handle, which is used to be passed to the native APIs. Please avoid using it | |||||
| /// unless you know what is the usage of the Native API. | |||||
| /// The native handle, which is used to be passed to the native APIs | |||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>Be careful how you use this!</remarks> | |||||
| public SafeLLamaContextHandle NativeHandle => _ctx; | public SafeLLamaContextHandle NativeHandle => _ctx; | ||||
| /// <summary> | /// <summary> | ||||
| /// The encoding set for this model to deal with text input. | /// The encoding set for this model to deal with text input. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -59,17 +73,59 @@ namespace LLama | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="Params">Model params.</param> | |||||
| /// <param name="encoding">Encoding to deal with text input.</param> | |||||
| /// <param name="params">Model params.</param> | |||||
| /// <param name="logger">The logger.</param> | /// <param name="logger">The logger.</param> | ||||
| public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null) | |||||
| [Obsolete("Use the LLamaWeights.CreateContext instead")] | |||||
| public LLamaContext(IModelParams @params, ILLamaLogger? logger = null) | |||||
| { | |||||
| Params = @params; | |||||
| _logger = logger; | |||||
| _encoding = @params.Encoding; | |||||
| _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); | |||||
| _ctx = Utils.InitLLamaContextFromModelParams(Params); | |||||
| } | |||||
| internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILLamaLogger? logger = null) | |||||
| { | { | ||||
| Params = @params; | |||||
| _logger = logger; | _logger = logger; | ||||
| this.Params = Params; | |||||
| _encoding = Encoding.GetEncoding(encoding); | |||||
| _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); | |||||
| _ctx = Utils.InitLLamaContextFromModelParams(this.Params); | |||||
| ContextSize = NativeApi.llama_n_ctx(_ctx); | |||||
| _encoding = @params.Encoding; | |||||
| _ctx = nativeContext; | |||||
| } | |||||
| /// <summary> | |||||
| /// Create a new LLamaContext for the given LLamaWeights | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <param name="params"></param> | |||||
| /// <param name="logger"></param> | |||||
| /// <exception cref="ObjectDisposedException"></exception> | |||||
| public LLamaContext(LLamaWeights model, IModelParams @params, ILLamaLogger? logger = null) | |||||
| { | |||||
| if (model.NativeHandle.IsClosed) | |||||
| throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); | |||||
| Params = @params; | |||||
| _logger = logger; | |||||
| _encoding = @params.Encoding; | |||||
| using var pin = @params.ToLlamaContextParams(out var lparams); | |||||
| _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); | |||||
| } | |||||
| /// <summary> | |||||
| /// Create a copy of the current state of this context | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public LLamaContext Clone() | |||||
| { | |||||
| using var pin = Params.ToLlamaContextParams(out var lparams); | |||||
| var clone = _ctx.Clone(lparams); | |||||
| return new LLamaContext(clone, Params); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -90,9 +146,10 @@ namespace LLama | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public string DeTokenize(IEnumerable<llama_token> tokens) | public string DeTokenize(IEnumerable<llama_token> tokens) | ||||
| { | { | ||||
| StringBuilder sb = new(); | |||||
| foreach(var token in tokens) | |||||
| sb.Append(_ctx.TokenToString(token, _encoding)); | |||||
| var sb = new StringBuilder(); | |||||
| foreach (var token in tokens) | |||||
| _ctx.TokenToString(token, _encoding, sb); | |||||
| return sb.ToString(); | return sb.ToString(); | ||||
| } | } | ||||
| @@ -147,7 +204,7 @@ namespace LLama | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public State GetState() | public State GetState() | ||||
| { | { | ||||
| var stateSize = NativeApi.llama_get_state_size(_ctx); | |||||
| var stateSize = _ctx.GetStateSize(); | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| @@ -156,15 +213,17 @@ namespace LLama | |||||
| try | try | ||||
| { | { | ||||
| // Copy the state data into "big memory", discover the actual size required | // Copy the state data into "big memory", discover the actual size required | ||||
| var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory); | |||||
| var actualSize = _ctx.GetState(bigMemory, stateSize); | |||||
| // if big memory is nearly completely full (within 1MB) early exit and skip the extra copying | |||||
| if (actualSize >= stateSize - 1_000_000) | |||||
| return new State(bigMemory); | |||||
| // Allocate a smaller buffer | |||||
| // Allocate a smaller buffer which is exactly the right size | |||||
| smallMemory = Marshal.AllocHGlobal((nint)actualSize); | smallMemory = Marshal.AllocHGlobal((nint)actualSize); | ||||
| // Copy into the smaller buffer and free the large one to save excess memory usage | // Copy into the smaller buffer and free the large one to save excess memory usage | ||||
| Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); | Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); | ||||
| Marshal.FreeHGlobal(bigMemory); | |||||
| bigMemory = IntPtr.Zero; | |||||
| return new State(smallMemory); | return new State(smallMemory); | ||||
| } | } | ||||
| @@ -224,7 +283,7 @@ namespace LLama | |||||
| { | { | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer()); | |||||
| _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -241,11 +300,19 @@ namespace LLama | |||||
| /// <param name="topP"></param> | /// <param name="topP"></param> | ||||
| /// <param name="tfsZ"></param> | /// <param name="tfsZ"></param> | ||||
| /// <param name="typicalP"></param> | /// <param name="typicalP"></param> | ||||
| /// <param name="grammar"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, | public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, | ||||
| float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) | |||||
| float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f, | |||||
| SafeLLamaGrammarHandle? grammar = null) | |||||
| { | { | ||||
| llama_token id; | llama_token id; | ||||
| if (grammar != null) | |||||
| { | |||||
| SamplingApi.llama_sample_grammar(_ctx, candidates, grammar); | |||||
| } | |||||
| if (temperature <= 0) | if (temperature <= 0) | ||||
| { | { | ||||
| // Greedy sampling | // Greedy sampling | ||||
| @@ -279,6 +346,12 @@ namespace LLama | |||||
| } | } | ||||
| mirostat_mu = mu; | mirostat_mu = mu; | ||||
| } | } | ||||
| if (grammar != null) | |||||
| { | |||||
| NativeApi.llama_grammar_accept_token(_ctx, grammar, id); | |||||
| } | |||||
| return id; | return id; | ||||
| } | } | ||||
| @@ -297,41 +370,89 @@ namespace LLama | |||||
| int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, | int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, | ||||
| bool penalizeNL = true) | bool penalizeNL = true) | ||||
| { | { | ||||
| var n_vocab = _ctx.VocabCount; | |||||
| var logits = _ctx.GetLogits(); | var logits = _ctx.GetLogits(); | ||||
| // Apply params.logit_bias map | // Apply params.logit_bias map | ||||
| if(logitBias is not null) | |||||
| if (logitBias is not null) | |||||
| { | { | ||||
| foreach (var (key, value) in logitBias) | foreach (var (key, value) in logitBias) | ||||
| { | |||||
| logits[key] += value; | logits[key] += value; | ||||
| } | |||||
| } | } | ||||
| var candidates = new LLamaTokenData[n_vocab]; | |||||
| for (llama_token token_id = 0; token_id < n_vocab; token_id++) | |||||
| candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); | |||||
| LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); | |||||
| // Apply penalties | |||||
| float nl_logit = logits[NativeApi.llama_token_nl()]; | |||||
| int lastTokensCount = lastTokens.Count(); | |||||
| var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize); | |||||
| SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, | |||||
| lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), | |||||
| (ulong)last_n_repeat, repeatPenalty); | |||||
| SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, | |||||
| lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), | |||||
| (ulong)last_n_repeat, alphaFrequency, alphaPresence); | |||||
| // Save the newline logit value | |||||
| var nl_token = NativeApi.llama_token_nl(_ctx); | |||||
| var nl_logit = logits[nl_token]; | |||||
| // Convert logits into token candidates | |||||
| var candidates_p = LLamaTokenDataArray.Create(logits); | |||||
| // Extract most recently returned tokens | |||||
| var last_n_repeat = Math.Min(ContextSize, repeatLastTokensCount); | |||||
| var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); | |||||
| // Apply penalties to candidates | |||||
| SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty); | |||||
| SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence); | |||||
| // Restore newline token logit value if necessary | |||||
| if (!penalizeNL) | if (!penalizeNL) | ||||
| { | { | ||||
| logits[NativeApi.llama_token_nl()] = nl_logit; | |||||
| var candidatesSpan = candidates_p.data.Span; | |||||
| for (var i = 0; i < candidates_p.data.Length; i++) | |||||
| { | |||||
| ref var item = ref candidatesSpan[i]; | |||||
| if (item.id == nl_token) | |||||
| item.logit = nl_logit; | |||||
| } | |||||
| candidates_p.sorted = false; | |||||
| } | } | ||||
| return candidates_p; | return candidates_p; | ||||
| } | } | ||||
| #region eval overloads | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="pastTokensCount"></param> | |||||
| /// <returns>The updated `pastTokensCount`.</returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public int Eval(llama_token[] tokens, llama_token pastTokensCount) | |||||
| { | |||||
| return Eval(tokens.AsSpan(), pastTokensCount); | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="pastTokensCount"></param> | |||||
| /// <returns>The updated `pastTokensCount`.</returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public int Eval(List<llama_token> tokens, llama_token pastTokensCount) | |||||
| { | |||||
| #if NET5_0_OR_GREATER | |||||
| var span = CollectionsMarshal.AsSpan(tokens); | |||||
| return Eval(span, pastTokensCount); | |||||
| #else | |||||
| // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of | |||||
| // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't | |||||
| // avoid the copying. | |||||
| var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count); | |||||
| try | |||||
| { | |||||
| tokens.CopyTo(rented, 0); | |||||
| return Eval(rented, pastTokensCount); | |||||
| } | |||||
| finally | |||||
| { | |||||
| System.Buffers.ArrayPool<llama_token>.Shared.Return(rented); | |||||
| } | |||||
| #endif | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -339,20 +460,32 @@ namespace LLama | |||||
| /// <param name="pastTokensCount"></param> | /// <param name="pastTokensCount"></param> | ||||
| /// <returns>The updated `pastTokensCount`.</returns> | /// <returns>The updated `pastTokensCount`.</returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) | |||||
| public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount) | |||||
| { | { | ||||
| int total = tokens.Length; | |||||
| for(int i = 0; i < total; i += Params.BatchSize) | |||||
| return Eval(tokens.Span, pastTokensCount); | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="pastTokensCount"></param> | |||||
| /// <returns>The updated `pastTokensCount`.</returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount) | |||||
| { | |||||
| var total = tokens.Length; | |||||
| for(var i = 0; i < total; i += Params.BatchSize) | |||||
| { | { | ||||
| int n_eval = total - i; | |||||
| if(n_eval > Params.BatchSize) | |||||
| var n_eval = total - i; | |||||
| if (n_eval > Params.BatchSize) | |||||
| { | { | ||||
| n_eval = Params.BatchSize; | n_eval = Params.BatchSize; | ||||
| } | } | ||||
| if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads)) | |||||
| if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads)) | |||||
| { | { | ||||
| _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error); | |||||
| _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error); | |||||
| throw new RuntimeError("Failed to eval."); | throw new RuntimeError("Failed to eval."); | ||||
| } | } | ||||
| @@ -360,16 +493,26 @@ namespace LLama | |||||
| } | } | ||||
| return pastTokensCount; | return pastTokensCount; | ||||
| } | } | ||||
| #endregion | |||||
| // TODO: add comment | |||||
| internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids) | internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids) | ||||
| { | { | ||||
| foreach(var id in ids) | foreach(var id in ids) | ||||
| yield return _ctx.TokenToString(id, _encoding); | yield return _ctx.TokenToString(id, _encoding); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Convert a token into a string | |||||
| /// </summary> | |||||
| /// <param name="token"></param> | |||||
| /// <returns></returns> | |||||
| public string TokenToString(llama_token token) | |||||
| { | |||||
| return NativeHandle.TokenToString(token, Encoding); | |||||
| } | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public virtual void Dispose() | |||||
| public void Dispose() | |||||
| { | { | ||||
| _ctx.Dispose(); | _ctx.Dispose(); | ||||
| } | } | ||||
| @@ -378,12 +521,11 @@ namespace LLama | |||||
| /// The state of this model, which can be reloaded later | /// The state of this model, which can be reloaded later | ||||
| /// </summary> | /// </summary> | ||||
| public class State | public class State | ||||
| : SafeHandleZeroOrMinusOneIsInvalid | |||||
| : SafeLLamaHandleBase | |||||
| { | { | ||||
| internal State(IntPtr memory) | internal State(IntPtr memory) | ||||
| : base(true) | |||||
| : base(memory) | |||||
| { | { | ||||
| SetHandle(memory); | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -1,9 +1,6 @@ | |||||
| using LLama.Native; | using LLama.Native; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| using System.Linq; | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| namespace LLama | namespace LLama | ||||
| @@ -11,18 +8,15 @@ namespace LLama | |||||
| /// <summary> | /// <summary> | ||||
| /// The embedder for LLama, which supports getting embeddings from text. | /// The embedder for LLama, which supports getting embeddings from text. | ||||
| /// </summary> | /// </summary> | ||||
| public class LLamaEmbedder : IDisposable | |||||
| public sealed class LLamaEmbedder | |||||
| : IDisposable | |||||
| { | { | ||||
| SafeLLamaContextHandle _ctx; | |||||
| private readonly LLamaContext _ctx; | |||||
| /// <summary> | /// <summary> | ||||
| /// Warning: must ensure the original model has params.embedding = true; | |||||
| /// Dimension of embedding vectors | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx"></param> | |||||
| internal LLamaEmbedder(SafeLLamaContextHandle ctx) | |||||
| { | |||||
| _ctx = ctx; | |||||
| } | |||||
| public int EmbeddingSize => _ctx.EmbeddingSize; | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -31,52 +25,67 @@ namespace LLama | |||||
| public LLamaEmbedder(IModelParams @params) | public LLamaEmbedder(IModelParams @params) | ||||
| { | { | ||||
| @params.EmbeddingMode = true; | @params.EmbeddingMode = true; | ||||
| _ctx = Utils.InitLLamaContextFromModelParams(@params); | |||||
| using var weights = LLamaWeights.LoadFromFile(@params); | |||||
| _ctx = weights.CreateContext(@params); | |||||
| } | |||||
| public LLamaEmbedder(LLamaWeights weights, IModelParams @params) | |||||
| { | |||||
| _ctx = weights.CreateContext(@params); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Get the embeddings of the text. | /// Get the embeddings of the text. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="text"></param> | /// <param name="text"></param> | ||||
| /// <param name="threads">Threads used for inference.</param> | |||||
| /// <param name="threads">unused</param> | |||||
| /// <param name="addBos">Add bos to the text.</param> | /// <param name="addBos">Add bos to the text.</param> | ||||
| /// <param name="encoding"></param> | |||||
| /// <param name="encoding">unused</param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") | |||||
| [Obsolete("'threads' and 'encoding' parameters are no longer used")] | |||||
| // ReSharper disable once MethodOverloadWithOptionalParameter | |||||
| public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") | |||||
| { | { | ||||
| if (threads == -1) | |||||
| { | |||||
| threads = Math.Max(Environment.ProcessorCount / 2, 1); | |||||
| } | |||||
| int n_past = 0; | |||||
| if (addBos) | |||||
| { | |||||
| text = text.Insert(0, " "); | |||||
| } | |||||
| return GetEmbeddings(text, addBos); | |||||
| } | |||||
| var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding)); | |||||
| /// <summary> | |||||
| /// Get the embeddings of the text. | |||||
| /// </summary> | |||||
| /// <param name="text"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public float[] GetEmbeddings(string text) | |||||
| { | |||||
| return GetEmbeddings(text, true); | |||||
| } | |||||
| /// <summary> | |||||
| /// Get the embeddings of the text. | |||||
| /// </summary> | |||||
| /// <param name="text"></param> | |||||
| /// <param name="addBos">Add bos to the text.</param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public float[] GetEmbeddings(string text, bool addBos) | |||||
| { | |||||
| var embed_inp_array = _ctx.Tokenize(text, addBos); | |||||
| // TODO(Rinne): deal with log of prompt | // TODO(Rinne): deal with log of prompt | ||||
| if (embed_inp_array.Length > 0) | if (embed_inp_array.Length > 0) | ||||
| { | |||||
| if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, threads) != 0) | |||||
| { | |||||
| throw new RuntimeError("Failed to eval."); | |||||
| } | |||||
| } | |||||
| _ctx.Eval(embed_inp_array, 0); | |||||
| int n_embed = NativeApi.llama_n_embd(_ctx); | |||||
| var embeddings = NativeApi.llama_get_embeddings(_ctx); | |||||
| if (embeddings == null) | |||||
| unsafe | |||||
| { | { | ||||
| return Array.Empty<float>(); | |||||
| var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle); | |||||
| if (embeddings == null) | |||||
| return Array.Empty<float>(); | |||||
| return new Span<float>(embeddings, EmbeddingSize).ToArray(); | |||||
| } | } | ||||
| var span = new Span<float>(embeddings, n_embed); | |||||
| float[] res = new float[n_embed]; | |||||
| span.CopyTo(res.AsSpan()); | |||||
| return res; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -18,10 +18,6 @@ namespace LLama | |||||
| /// </summary> | /// </summary> | ||||
| public abstract class StatefulExecutorBase : ILLamaExecutor | public abstract class StatefulExecutorBase : ILLamaExecutor | ||||
| { | { | ||||
| /// <summary> | |||||
| /// The loaded model for this executor. | |||||
| /// </summary> | |||||
| protected readonly LLamaModel _model; | |||||
| /// <summary> | /// <summary> | ||||
| /// The logger used by this executor. | /// The logger used by this executor. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -63,9 +59,9 @@ namespace LLama | |||||
| /// </summary> | /// </summary> | ||||
| protected FixedSizeQueue<llama_token> _last_n_tokens; | protected FixedSizeQueue<llama_token> _last_n_tokens; | ||||
| /// <summary> | /// <summary> | ||||
| /// The mode used by the executor. | |||||
| /// The context used by the executor. | |||||
| /// </summary> | /// </summary> | ||||
| public LLamaModel Model => _model; | |||||
| public LLamaContext Context { get; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Current "mu" value for mirostat sampling | /// Current "mu" value for mirostat sampling | ||||
| @@ -75,16 +71,16 @@ namespace LLama | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="model"></param> | |||||
| /// <param name="context"></param> | |||||
| /// <param name="logger"></param> | /// <param name="logger"></param> | ||||
| protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null) | |||||
| protected StatefulExecutorBase(LLamaContext context, ILLamaLogger? logger = null) | |||||
| { | { | ||||
| _model = model; | |||||
| Context = context; | |||||
| _logger = logger; | _logger = logger; | ||||
| _pastTokensCount = 0; | _pastTokensCount = 0; | ||||
| _consumedTokensCount = 0; | _consumedTokensCount = 0; | ||||
| _n_session_consumed = 0; | _n_session_consumed = 0; | ||||
| _last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0); | |||||
| _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize).FillWith(0); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -104,9 +100,9 @@ namespace LLama | |||||
| if (File.Exists(filename)) | if (File.Exists(filename)) | ||||
| { | { | ||||
| _logger?.Log("LLamaExecutor", $"Attempting to load saved session from {filename}", ILLamaLogger.LogLevel.Info); | _logger?.Log("LLamaExecutor", $"Attempting to load saved session from {filename}", ILLamaLogger.LogLevel.Info); | ||||
| llama_token[] session_tokens = new llama_token[_model.ContextSize]; | |||||
| llama_token[] session_tokens = new llama_token[Context.ContextSize]; | |||||
| ulong n_token_count_out = 0; | ulong n_token_count_out = 0; | ||||
| if (!NativeApi.llama_load_session_file(_model.NativeHandle, _pathSession, session_tokens, (ulong)_model.ContextSize, &n_token_count_out)) | |||||
| if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, &n_token_count_out)) | |||||
| { | { | ||||
| _logger?.Log("LLamaExecutor", $"Failed to load session file {filename}", ILLamaLogger.LogLevel.Error); | _logger?.Log("LLamaExecutor", $"Failed to load session file {filename}", ILLamaLogger.LogLevel.Error); | ||||
| throw new RuntimeError($"Failed to load session file {_pathSession}"); | throw new RuntimeError($"Failed to load session file {_pathSession}"); | ||||
| @@ -156,7 +152,7 @@ namespace LLama | |||||
| public void SaveSessionFile(string filename) | public void SaveSessionFile(string filename) | ||||
| { | { | ||||
| var session_token_array = _session_tokens.ToArray(); | var session_token_array = _session_tokens.ToArray(); | ||||
| NativeApi.llama_save_session_file(_model.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); | |||||
| NativeApi.llama_save_session_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -173,7 +169,7 @@ namespace LLama | |||||
| _pastTokensCount = Math.Max(1, tokensToKeep); | _pastTokensCount = Math.Max(1, tokensToKeep); | ||||
| // insert n_left/2 tokens at the start of embed from last_n_tokens | // insert n_left/2 tokens at the start of embed from last_n_tokens | ||||
| _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(_model.ContextSize - n_left / 2 - _embeds.Count)); | |||||
| _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(Context.ContextSize - n_left / 2 - _embeds.Count)); | |||||
| // stop saving session if we run out of context | // stop saving session if we run out of context | ||||
| _pathSession = string.Empty; | _pathSession = string.Empty; | ||||
| @@ -270,10 +266,7 @@ namespace LLama | |||||
| public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) | public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) | ||||
| { | { | ||||
| cancellationToken.ThrowIfCancellationRequested(); | cancellationToken.ThrowIfCancellationRequested(); | ||||
| if (inferenceParams is null) | |||||
| { | |||||
| inferenceParams = new InferenceParams(); | |||||
| } | |||||
| inferenceParams ??= new InferenceParams(); | |||||
| InferStateArgs args = new InferStateArgs() | InferStateArgs args = new InferStateArgs() | ||||
| { | { | ||||
| @@ -296,7 +289,7 @@ namespace LLama | |||||
| if (args.ReturnValue) | if (args.ReturnValue) | ||||
| { | { | ||||
| foreach (var item in _model.GenerateResult(_embeds)) | |||||
| foreach (var item in Context.GenerateResult(_embeds)) | |||||
| { | { | ||||
| yield return item; | yield return item; | ||||
| } | } | ||||
| @@ -374,7 +367,7 @@ namespace LLama | |||||
| public int MatchingSessionTokensCount { get; set; } | public int MatchingSessionTokensCount { get; set; } | ||||
| [JsonPropertyName("path_session")] | [JsonPropertyName("path_session")] | ||||
| public string SessionFilePath { get; set; } | |||||
| public string? SessionFilePath { get; set; } | |||||
| [JsonPropertyName("embd")] | [JsonPropertyName("embd")] | ||||
| public List<llama_token> Embeds { get; set; } | public List<llama_token> Embeds { get; set; } | ||||
| @@ -5,6 +5,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | |||||
| using System.Text.Json; | using System.Text.Json; | ||||
| using System.Text.Json.Serialization; | using System.Text.Json.Serialization; | ||||
| @@ -24,14 +25,14 @@ namespace LLama | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="model"></param> | |||||
| /// <param name="context"></param> | |||||
| /// <param name="instructionPrefix"></param> | /// <param name="instructionPrefix"></param> | ||||
| /// <param name="instructionSuffix"></param> | /// <param name="instructionSuffix"></param> | ||||
| public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n", | |||||
| string instructionSuffix = "\n\n### Response:\n\n") : base(model) | |||||
| public InstructExecutor(LLamaContext context, string instructionPrefix = "\n\n### Instruction:\n\n", | |||||
| string instructionSuffix = "\n\n### Response:\n\n") : base(context) | |||||
| { | { | ||||
| _inp_pfx = _model.Tokenize(instructionPrefix, true); | |||||
| _inp_sfx = _model.Tokenize(instructionSuffix, false); | |||||
| _inp_pfx = Context.Tokenize(instructionPrefix, true); | |||||
| _inp_sfx = Context.Tokenize(instructionSuffix, false); | |||||
| _instructionPrefix = instructionPrefix; | _instructionPrefix = instructionPrefix; | ||||
| } | } | ||||
| @@ -84,16 +85,16 @@ namespace LLama | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public override void SaveState(string filename) | public override void SaveState(string filename) | ||||
| { | { | ||||
| InstructExecutorState state = GetStateData() as InstructExecutorState; | |||||
| using (FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write)) | |||||
| var state = (InstructExecutorState)GetStateData(); | |||||
| using (var fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write)) | |||||
| { | { | ||||
| JsonSerializer.Serialize<InstructExecutorState>(fs, state); | |||||
| JsonSerializer.Serialize(fs, state); | |||||
| } | } | ||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public override void LoadState(string filename) | public override void LoadState(string filename) | ||||
| { | { | ||||
| using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read)) | |||||
| using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read)) | |||||
| { | { | ||||
| var state = JsonSerializer.Deserialize<InstructExecutorState>(fs); | var state = JsonSerializer.Deserialize<InstructExecutorState>(fs); | ||||
| LoadState(state); | LoadState(state); | ||||
| @@ -108,16 +109,12 @@ namespace LLama | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override void PreprocessInputs(string text, InferStateArgs args) | protected override void PreprocessInputs(string text, InferStateArgs args) | ||||
| { | { | ||||
| if(args.Antiprompts is null) | |||||
| { | |||||
| args.Antiprompts = new List<string>(); | |||||
| } | |||||
| args.Antiprompts ??= new List<string>(); | |||||
| args.Antiprompts.Add(_instructionPrefix); | args.Antiprompts.Add(_instructionPrefix); | ||||
| if (_is_prompt_run) | if (_is_prompt_run) | ||||
| { | { | ||||
| // When running the first input (prompt) in inteactive mode, we should specially process it. | // When running the first input (prompt) in inteactive mode, we should specially process it. | ||||
| text = " " + text; | |||||
| _embed_inps = _model.Tokenize(text, true).ToList(); | |||||
| _embed_inps = Context.Tokenize(text, true).ToList(); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -128,7 +125,7 @@ namespace LLama | |||||
| _consumedTokensCount = _embed_inps.Count; | _consumedTokensCount = _embed_inps.Count; | ||||
| _embed_inps.AddRange(_inp_pfx); | _embed_inps.AddRange(_inp_pfx); | ||||
| var line_inp = _model.Tokenize(text, false); | |||||
| var line_inp = Context.Tokenize(text, false); | |||||
| _embed_inps.AddRange(line_inp); | _embed_inps.AddRange(line_inp); | ||||
| _embed_inps.AddRange(_inp_sfx); | _embed_inps.AddRange(_inp_sfx); | ||||
| @@ -144,9 +141,10 @@ namespace LLama | |||||
| { | { | ||||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | ||||
| { | { | ||||
| string last_output = ""; | |||||
| foreach (var id in _last_n_tokens) | |||||
| last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); | |||||
| var last_output_builder = new StringBuilder(); | |||||
| foreach (var token in _last_n_tokens) | |||||
| Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder); | |||||
| var last_output = last_output_builder.ToString(); | |||||
| foreach (var antiprompt in args.Antiprompts) | foreach (var antiprompt in args.Antiprompts) | ||||
| { | { | ||||
| @@ -160,12 +158,12 @@ namespace LLama | |||||
| if (_pastTokensCount > 0 && args.WaitForInput) | if (_pastTokensCount > 0 && args.WaitForInput) | ||||
| { | { | ||||
| extraOutputs = new string[] { "\n> " }; | |||||
| extraOutputs = new[] { "\n> " }; | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) | |||||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) | |||||
| { | { | ||||
| args.WaitForInput = true; | args.WaitForInput = true; | ||||
| } | } | ||||
| @@ -183,13 +181,13 @@ namespace LLama | |||||
| if (_embeds.Count > 0) | if (_embeds.Count > 0) | ||||
| { | { | ||||
| _is_prompt_run = false; | _is_prompt_run = false; | ||||
| if (_pastTokensCount + _embeds.Count > _model.ContextSize) | |||||
| if (_pastTokensCount + _embeds.Count > Context.ContextSize) | |||||
| { | { | ||||
| HandleRunOutOfContext(inferenceParams.TokensKeep); | HandleRunOutOfContext(inferenceParams.TokensKeep); | ||||
| } | } | ||||
| TryReuseMathingPrefix(); | TryReuseMathingPrefix(); | ||||
| _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); | |||||
| _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); | |||||
| if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | ||||
| { | { | ||||
| @@ -202,7 +200,7 @@ namespace LLama | |||||
| if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) | if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) | ||||
| { | { | ||||
| var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; | |||||
| var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; | |||||
| // optionally save the session on first sample (for faster prompt loading next time) | // optionally save the session on first sample (for faster prompt loading next time) | ||||
| if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) | if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) | ||||
| @@ -211,13 +209,14 @@ namespace LLama | |||||
| SaveSessionFile(_pathSession); | SaveSessionFile(_pathSession); | ||||
| } | } | ||||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | ||||
| var mu = MirostatMu; | var mu = MirostatMu; | ||||
| var id = _model.Sample( | |||||
| var id = Context.Sample( | |||||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | ||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, | |||||
| inferenceParams.Grammar | |||||
| ); | ); | ||||
| MirostatMu = mu; | MirostatMu = mu; | ||||
| @@ -235,7 +234,7 @@ namespace LLama | |||||
| _embeds.Add(_embed_inps[_consumedTokensCount]); | _embeds.Add(_embed_inps[_consumedTokensCount]); | ||||
| _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | ||||
| _consumedTokensCount++; | _consumedTokensCount++; | ||||
| if (_embeds.Count >= _model.Params.BatchSize) | |||||
| if (_embeds.Count >= Context.Params.BatchSize) | |||||
| { | { | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -7,6 +7,7 @@ using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text.Json; | using System.Text.Json; | ||||
| using System.Text.Json.Serialization; | using System.Text.Json.Serialization; | ||||
| using System.Text; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| @@ -22,10 +23,10 @@ namespace LLama | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="model"></param> | |||||
| public InteractiveExecutor(LLamaModel model) : base(model) | |||||
| /// <param name="context"></param> | |||||
| public InteractiveExecutor(LLamaContext context) : base(context) | |||||
| { | { | ||||
| _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding); | |||||
| _llama_token_newline = new [] { NativeApi.llama_token_nl(Context.NativeHandle) }; | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -72,10 +73,10 @@ namespace LLama | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public override void SaveState(string filename) | public override void SaveState(string filename) | ||||
| { | { | ||||
| InteractiveExecutorState state = GetStateData() as InteractiveExecutorState; | |||||
| InteractiveExecutorState state = (InteractiveExecutorState)GetStateData(); | |||||
| using(FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write)) | using(FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write)) | ||||
| { | { | ||||
| JsonSerializer.Serialize<InteractiveExecutorState>(fs, state); | |||||
| JsonSerializer.Serialize(fs, state); | |||||
| } | } | ||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -103,8 +104,7 @@ namespace LLama | |||||
| if (_is_prompt_run) | if (_is_prompt_run) | ||||
| { | { | ||||
| // When running the first input (prompt) in inteactive mode, we should specially process it. | // When running the first input (prompt) in inteactive mode, we should specially process it. | ||||
| text = " " + text; | |||||
| _embed_inps = _model.Tokenize(text, true).ToList(); | |||||
| _embed_inps = Context.Tokenize(text, true).ToList(); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -112,7 +112,7 @@ namespace LLama | |||||
| { | { | ||||
| text += "\n"; | text += "\n"; | ||||
| } | } | ||||
| var line_inp = _model.Tokenize(text, false); | |||||
| var line_inp = Context.Tokenize(text, false); | |||||
| _embed_inps.AddRange(line_inp); | _embed_inps.AddRange(line_inp); | ||||
| args.RemainedTokens -= line_inp.Length; | args.RemainedTokens -= line_inp.Length; | ||||
| } | } | ||||
| @@ -121,7 +121,9 @@ namespace LLama | |||||
| /// <summary> | /// <summary> | ||||
| /// Return whether to break the generation. | /// Return whether to break the generation. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="inferenceParams"></param> | |||||
| /// <param name="args"></param> | /// <param name="args"></param> | ||||
| /// <param name="extraOutputs"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs) | protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs) | ||||
| { | { | ||||
| @@ -130,11 +132,10 @@ namespace LLama | |||||
| { | { | ||||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | ||||
| { | { | ||||
| string last_output = ""; | |||||
| foreach (var id in _last_n_tokens) | |||||
| { | |||||
| last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); | |||||
| } | |||||
| var last_output_builder = new StringBuilder(); | |||||
| foreach (var token in _last_n_tokens) | |||||
| Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder); | |||||
| var last_output = last_output_builder.ToString(); | |||||
| foreach (var antiprompt in args.Antiprompts) | foreach (var antiprompt in args.Antiprompts) | ||||
| { | { | ||||
| @@ -152,9 +153,9 @@ namespace LLama | |||||
| } | } | ||||
| } | } | ||||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) | |||||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) | |||||
| { | { | ||||
| extraOutputs = new string[] { " [end of text]\n" }; | |||||
| extraOutputs = new[] { " [end of text]\n" }; | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -172,13 +173,13 @@ namespace LLama | |||||
| if (_embeds.Count > 0) | if (_embeds.Count > 0) | ||||
| { | { | ||||
| _is_prompt_run = false; | _is_prompt_run = false; | ||||
| if (_pastTokensCount + _embeds.Count > _model.ContextSize) | |||||
| if (_pastTokensCount + _embeds.Count > Context.ContextSize) | |||||
| { | { | ||||
| HandleRunOutOfContext(inferenceParams.TokensKeep); | HandleRunOutOfContext(inferenceParams.TokensKeep); | ||||
| } | } | ||||
| TryReuseMathingPrefix(); | TryReuseMathingPrefix(); | ||||
| _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); | |||||
| _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); | |||||
| if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | ||||
| { | { | ||||
| @@ -191,7 +192,7 @@ namespace LLama | |||||
| if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) | if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) | ||||
| { | { | ||||
| var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; | |||||
| var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; | |||||
| // optionally save the session on first sample (for faster prompt loading next time) | // optionally save the session on first sample (for faster prompt loading next time) | ||||
| if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) | if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) | ||||
| @@ -200,24 +201,25 @@ namespace LLama | |||||
| SaveSessionFile(_pathSession); | SaveSessionFile(_pathSession); | ||||
| } | } | ||||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | ||||
| var mu = MirostatMu; | var mu = MirostatMu; | ||||
| var id = _model.Sample( | |||||
| var id = Context.Sample( | |||||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | ||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, | |||||
| inferenceParams.Grammar | |||||
| ); | ); | ||||
| MirostatMu = mu; | MirostatMu = mu; | ||||
| _last_n_tokens.Enqueue(id); | _last_n_tokens.Enqueue(id); | ||||
| if (id == NativeApi.llama_token_eos()) | |||||
| if (id == NativeApi.llama_token_eos(Context.NativeHandle)) | |||||
| { | { | ||||
| id = _llama_token_newline.First(); | id = _llama_token_newline.First(); | ||||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | ||||
| { | { | ||||
| var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false); | |||||
| var first_antiprompt = Context.Tokenize(args.Antiprompts[0], false); | |||||
| _embed_inps.AddRange(first_antiprompt); | _embed_inps.AddRange(first_antiprompt); | ||||
| } | } | ||||
| } | } | ||||
| @@ -234,7 +236,7 @@ namespace LLama | |||||
| _embeds.Add(_embed_inps[_consumedTokensCount]); | _embeds.Add(_embed_inps[_consumedTokensCount]); | ||||
| _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | ||||
| _consumedTokensCount++; | _consumedTokensCount++; | ||||
| if (_embeds.Count >= _model.Params.BatchSize) | |||||
| if (_embeds.Count >= Context.Params.BatchSize) | |||||
| { | { | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -1,8 +1,6 @@ | |||||
| using LLama.Native; | using LLama.Native; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| @@ -36,8 +34,7 @@ namespace LLama | |||||
| quantizeParams.nthread = nthread; | quantizeParams.nthread = nthread; | ||||
| quantizeParams.allow_requantize = allowRequantize; | quantizeParams.allow_requantize = allowRequantize; | ||||
| quantizeParams.quantize_output_tensor = quantizeOutputTensor; | quantizeParams.quantize_output_tensor = quantizeOutputTensor; | ||||
| LLamaModelQuantizeParams* p = &quantizeParams; | |||||
| return NativeApi.llama_model_quantize(srcFileName, dstFilename, p) == 0; | |||||
| return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -57,42 +54,71 @@ namespace LLama | |||||
| return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread, allowRequantize, quantizeOutputTensor); | return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread, allowRequantize, quantizeOutputTensor); | ||||
| } | } | ||||
| private static bool ValidateFtype(string ftype) | |||||
| { | |||||
| return new string[] { "q4_0", "q4_1", "q5_0", "q5_1", "q8_0" }.Contains(ftype); | |||||
| } | |||||
| private static bool ValidateFtype(LLamaFtype ftype) | private static bool ValidateFtype(LLamaFtype ftype) | ||||
| { | { | ||||
| return ftype is LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1 | |||||
| or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0; | |||||
| } | |||||
| // Validation copies from here: | |||||
| // https://github.com/ggerganov/llama.cpp/blob/e59fcb2bc129881f4a269fee748fb38bce0a64de/llama.cpp#L2960 | |||||
| private static string FtypeToString(LLamaFtype ftype) | |||||
| { | |||||
| return ftype switch | |||||
| switch (ftype) | |||||
| { | { | ||||
| LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0 => "q4_0", | |||||
| LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1 => "q4_1", | |||||
| LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0 => "q5_0", | |||||
| LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1 => "q5_1", | |||||
| LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0 => "q8_0", | |||||
| _ => throw new ArgumentException($"The type {Enum.GetName(typeof(LLamaFtype), ftype)} is not a valid type " + | |||||
| $"to perform quantization.") | |||||
| }; | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_F16: | |||||
| case LLamaFtype.LLAMA_FTYPE_ALL_F32: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_S: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_M: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_L: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_S: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_M: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_S: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_M: | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q6_K: | |||||
| return true; | |||||
| case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: | |||||
| default: | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Parse a string into a LLamaFtype. This is a "relaxed" parsing, which allows any string which is contained within | |||||
| /// the enum name to be used. | |||||
| /// | |||||
| /// For example "Q5_K_M" will convert to "LLAMA_FTYPE_MOSTLY_Q5_K_M" | |||||
| /// </summary> | |||||
| /// <param name="str"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException"></exception> | |||||
| private static LLamaFtype StringToFtype(string str) | private static LLamaFtype StringToFtype(string str) | ||||
| { | { | ||||
| return str switch | |||||
| // Find all variants which contain the input string | |||||
| var matches = new List<LLamaFtype>(); | |||||
| foreach (LLamaFtype ftype in Enum.GetValues(typeof(LLamaFtype))) | |||||
| { | { | ||||
| "q4_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0, | |||||
| "q4_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1, | |||||
| "q5_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0, | |||||
| "q5_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1, | |||||
| "q8_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0, | |||||
| _ => throw new ArgumentException($"Invalid ftype {str} to quantize.") | |||||
| }; | |||||
| var name = Enum.GetName(typeof(LLamaFtype), ftype); | |||||
| // Note: this is using "IndexOf" instead of "Contains" to be compatible with netstandard2.0 | |||||
| #pragma warning disable CA2249 | |||||
| if (name != null && name.IndexOf(str, StringComparison.OrdinalIgnoreCase) >= 0) | |||||
| matches.Add(ftype); | |||||
| #pragma warning restore CA2249 | |||||
| } | |||||
| // If there was just one match, success! | |||||
| if (matches.Count == 1) | |||||
| return matches[0]; | |||||
| // If none matched throw a generic error | |||||
| if (matches.Count == 0) | |||||
| throw new ArgumentException($"Unknown ftype \"{str}\" for quantization."); | |||||
| // There were several matches, throw an error asking the user to be more specific | |||||
| throw new ArgumentException($"\"{str}\" matches multiple potential ftypes: {string.Join(",", matches)}"); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -32,11 +32,11 @@ | |||||
| <Link>libllama.dylib</Link> | <Link>libllama.dylib</Link> | ||||
| </None> | </None> | ||||
| <None Include="$(MSBuildThisFileDirectory)runtimes/libllama-metal.dylib"> | <None Include="$(MSBuildThisFileDirectory)runtimes/libllama-metal.dylib"> | ||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
| <CopyToOutputDirectory>None</CopyToOutputDirectory> | |||||
| <Link>libllama-metal.dylib</Link> | <Link>libllama-metal.dylib</Link> | ||||
| </None> | </None> | ||||
| <None Include="$(MSBuildThisFileDirectory)runtimes/ggml-metal.metal"> | <None Include="$(MSBuildThisFileDirectory)runtimes/ggml-metal.metal"> | ||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
| <CopyToOutputDirectory>None</CopyToOutputDirectory> | |||||
| <Link>ggml-metal.metal</Link> | <Link>ggml-metal.metal</Link> | ||||
| </None> | </None> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -1,125 +1,159 @@ | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.Native; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Text; | |||||
| using System.Threading; | using System.Threading; | ||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| /// <summary> | /// <summary> | ||||
| /// This executor infer the input as one-time job. Previous inputs won't impact on the | /// This executor infer the input as one-time job. Previous inputs won't impact on the | ||||
| /// response to current input. | /// response to current input. | ||||
| /// </summary> | /// </summary> | ||||
| public class StatelessExecutor : ILLamaExecutor | |||||
| public class StatelessExecutor | |||||
| : ILLamaExecutor | |||||
| { | { | ||||
| private LLamaModel _model; | |||||
| private LLamaModel.State _originalState; | |||||
| private readonly LLamaWeights _weights; | |||||
| private readonly IModelParams _params; | |||||
| /// <summary> | |||||
| /// The context used by the executor when running the inference. | |||||
| /// </summary> | |||||
| public LLamaContext Context { get; private set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// The mode used by the executor when running the inference. | |||||
| /// Create a new stateless executor which will use the given model | |||||
| /// </summary> | /// </summary> | ||||
| public LLamaModel Model => _model; | |||||
| /// <param name="weights"></param> | |||||
| /// <param name="params"></param> | |||||
| public StatelessExecutor(LLamaWeights weights, IModelParams @params) | |||||
| { | |||||
| _weights = weights; | |||||
| _params = @params; | |||||
| Context = _weights.CreateContext(_params); | |||||
| Context.Dispose(); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// | |||||
| /// Create a new stateless executor which will use the model used to create the given context | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="model">The LLama model.</param> | |||||
| public StatelessExecutor(LLamaModel model) | |||||
| /// <param name="context"></param> | |||||
| [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] | |||||
| public StatelessExecutor(LLamaContext context) | |||||
| { | { | ||||
| _model = model; | |||||
| var tokens = model.Tokenize(" ", true).ToArray(); | |||||
| _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads); | |||||
| _originalState = model.GetState(); | |||||
| _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding); | |||||
| _params = context.Params; | |||||
| Context = _weights.CreateContext(_params); | |||||
| Context.Dispose(); | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) | public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) | ||||
| { | { | ||||
| cancellationToken.ThrowIfCancellationRequested(); | |||||
| int n_past = 1; | |||||
| if(inferenceParams is null) | |||||
| { | |||||
| inferenceParams = new InferenceParams(); | |||||
| } | |||||
| List<llama_token> lastTokens = new(inferenceParams.RepeatLastTokensCount); | |||||
| for(int i = 0; i < lastTokens.Count; i++) | |||||
| using var context = _weights.CreateContext(_params); | |||||
| Context = context; | |||||
| if (!Context.NativeHandle.IsClosed) | |||||
| Context.Dispose(); | |||||
| Context = _weights.CreateContext(Context.Params); | |||||
| if (inferenceParams != null) | |||||
| { | { | ||||
| lastTokens[i] = 0; | |||||
| if (inferenceParams.TokensKeep > Context.ContextSize) | |||||
| throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); | |||||
| } | } | ||||
| List<llama_token> tokens = _model.Tokenize(text, true).ToList(); | |||||
| int n_prompt_tokens = tokens.Count; | |||||
| _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads); | |||||
| cancellationToken.ThrowIfCancellationRequested(); | |||||
| var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>(); | |||||
| var n_past = 1; | |||||
| inferenceParams ??= new InferenceParams(); | |||||
| var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount); | |||||
| for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++) | |||||
| lastTokens.Add(0); | |||||
| var tokens = Context.Tokenize(text).ToList(); | |||||
| var n_prompt_tokens = tokens.Count; | |||||
| Context.Eval(tokens, n_past); | |||||
| lastTokens.AddRange(tokens); | lastTokens.AddRange(tokens); | ||||
| n_past += n_prompt_tokens; | n_past += n_prompt_tokens; | ||||
| var mu = (float?)null; | var mu = (float?)null; | ||||
| int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | |||||
| for(int i = 0; i < max_tokens; i++) | |||||
| var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | |||||
| for(var i = 0; i < max_tokens; i++) | |||||
| { | { | ||||
| if (cancellationToken.IsCancellationRequested) | if (cancellationToken.IsCancellationRequested) | ||||
| { | |||||
| _model.LoadState(_originalState); | |||||
| break; | break; | ||||
| } | |||||
| var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; | |||||
| var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, | |||||
| var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; | |||||
| var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, | |||||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | ||||
| var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP); | |||||
| var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); | |||||
| lastTokens.Add(id); | lastTokens.Add(id); | ||||
| string response = _model.NativeHandle.TokenToString(id, _model.Encoding); | |||||
| var response = Context.TokenToString(id); | |||||
| yield return response; | yield return response; | ||||
| tokens.Clear(); | tokens.Clear(); | ||||
| tokens.Add(id); | tokens.Add(id); | ||||
| if (inferenceParams.AntiPrompts is not null && inferenceParams.AntiPrompts.Count() > 0) | |||||
| { | |||||
| string last_output = ""; | |||||
| foreach (var token in lastTokens) | |||||
| { | |||||
| last_output += _model.NativeHandle.TokenToString(token, _model.Encoding); | |||||
| } | |||||
| bool should_break = false; | |||||
| foreach (var antiprompt in inferenceParams.AntiPrompts) | |||||
| { | |||||
| if (last_output.EndsWith(antiprompt)) | |||||
| { | |||||
| should_break = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (should_break) | |||||
| { | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (EndsWithAntiprompt(lastTokens, antiprompts)) | |||||
| break; | |||||
| // when run out of context | // when run out of context | ||||
| if (n_past + tokens.Count > _model.ContextSize) | |||||
| // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 | |||||
| if (n_past + tokens.Count > Context.ContextSize) | |||||
| { | { | ||||
| int n_left = n_past - inferenceParams.TokensKeep; | |||||
| var n_left = n_past - inferenceParams.TokensKeep; | |||||
| n_past = Math.Max(1, inferenceParams.TokensKeep); | n_past = Math.Max(1, inferenceParams.TokensKeep); | ||||
| // insert n_left/2 tokens at the start of embed from last_n_tokens | |||||
| tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_model.ContextSize - n_left / 2 - tokens.Count)); | |||||
| tokens.Clear(); | |||||
| tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); | |||||
| } | } | ||||
| n_past = _model.Eval(tokens.ToArray(), n_past); | |||||
| n_past = Context.Eval(tokens, n_past); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Check if the given tokens list ends with any of the antiprompts | |||||
| /// </summary> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="antiprompts"></param> | |||||
| /// <returns></returns> | |||||
| private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts) | |||||
| { | |||||
| if (antiprompts.Count == 0 || tokens.Count == 0) | |||||
| return false; | |||||
| var builder = new StringBuilder(); | |||||
| foreach (var token in tokens) | |||||
| builder.Append(Context.TokenToString(token)); | |||||
| var last_output = builder.ToString(); | |||||
| foreach (var antiprompt in antiprompts) | |||||
| { | |||||
| if (last_output.EndsWith(antiprompt)) | |||||
| return true; | |||||
| } | } | ||||
| _model.LoadState(_originalState); | |||||
| return false; | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -0,0 +1,81 @@ | |||||
| using System; | |||||
| using System.Text; | |||||
| using LLama.Abstractions; | |||||
| using LLama.Extensions; | |||||
| using LLama.Native; | |||||
| namespace LLama | |||||
| { | |||||
| /// <summary> | |||||
| /// A set of model weights, loaded into memory. | |||||
| /// </summary> | |||||
| public sealed class LLamaWeights | |||||
| : IDisposable | |||||
| { | |||||
| private readonly SafeLlamaModelHandle _weights; | |||||
| /// <summary> | |||||
| /// The native handle, which is used in the native APIs | |||||
| /// </summary> | |||||
| /// <remarks>Be careful how you use this!</remarks> | |||||
| public SafeLlamaModelHandle NativeHandle => _weights; | |||||
| /// <summary> | |||||
| /// Encoding to use to convert text into bytes for the model | |||||
| /// </summary> | |||||
| public Encoding Encoding { get; } | |||||
| /// <summary> | |||||
| /// Total number of tokens in vocabulary of this model | |||||
| /// </summary> | |||||
| public int VocabCount => NativeHandle.VocabCount; | |||||
| /// <summary> | |||||
| /// Total number of tokens in the context | |||||
| /// </summary> | |||||
| public int ContextSize => NativeHandle.ContextSize; | |||||
| /// <summary> | |||||
| /// Dimension of embedding vectors | |||||
| /// </summary> | |||||
| public int EmbeddingSize => NativeHandle.EmbeddingSize; | |||||
| internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) | |||||
| { | |||||
| _weights = weights; | |||||
| Encoding = encoding; | |||||
| } | |||||
| /// <summary> | |||||
| /// Load weights into memory | |||||
| /// </summary> | |||||
| /// <param name="params"></param> | |||||
| /// <returns></returns> | |||||
| public static LLamaWeights LoadFromFile(IModelParams @params) | |||||
| { | |||||
| using var pin = @params.ToLlamaContextParams(out var lparams); | |||||
| var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); | |||||
| if (!string.IsNullOrEmpty(@params.LoraAdapter)) | |||||
| weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); | |||||
| return new LLamaWeights(weights, @params.Encoding); | |||||
| } | |||||
| /// <inheritdoc /> | |||||
| public void Dispose() | |||||
| { | |||||
| _weights.Dispose(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Create a llama_context using this model | |||||
| /// </summary> | |||||
| /// <param name="params"></param> | |||||
| /// <returns></returns> | |||||
| public LLamaContext CreateContext(IModelParams @params) | |||||
| { | |||||
| return new LLamaContext(this, @params); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,15 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace LLama.Native | |||||
| { | |||||
| internal struct GgmlInitParams | |||||
| { | |||||
| public ulong mem_size; | |||||
| public IntPtr mem_buffer; | |||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool no_alloc; | |||||
| } | |||||
| } | |||||
| @@ -1,11 +1,18 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Called by llama.cpp with a progress value between 0 and 1 | |||||
| /// </summary> | |||||
| /// <param name="progress"></param> | |||||
| /// <param name="ctx"></param> | |||||
| public delegate void LlamaProgressCallback(float progress, IntPtr ctx); | public delegate void LlamaProgressCallback(float progress, IntPtr ctx); | ||||
| /// <summary> | |||||
| /// A C# representation of the llama.cpp `llama_context_params` struct | |||||
| /// </summary> | |||||
| [StructLayout(LayoutKind.Sequential)] | [StructLayout(LayoutKind.Sequential)] | ||||
| public struct LLamaContextParams | public struct LLamaContextParams | ||||
| { | { | ||||
| @@ -24,16 +31,6 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| public int n_batch; | public int n_batch; | ||||
| /// <summary> | |||||
| /// grouped-query attention (TEMP - will be moved to model hparams) | |||||
| /// </summary> | |||||
| public int n_gqa; | |||||
| /// <summary> | |||||
| /// rms norm epsilon (TEMP - will be moved to model hparams) | |||||
| /// </summary> | |||||
| public float rms_norm_eps; | |||||
| /// <summary> | /// <summary> | ||||
| /// number of layers to store in VRAM | /// number of layers to store in VRAM | ||||
| /// </summary> | /// </summary> | ||||
| @@ -49,7 +46,6 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| public nint tensor_split; | public nint tensor_split; | ||||
| /// <summary> | /// <summary> | ||||
| /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 | /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 | ||||
| /// RoPE base frequency | /// RoPE base frequency | ||||
| @@ -72,53 +68,85 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| public IntPtr progress_callback_user_data; | public IntPtr progress_callback_user_data; | ||||
| /// <summary> | /// <summary> | ||||
| /// if true, reduce VRAM usage at the cost of performance | /// if true, reduce VRAM usage at the cost of performance | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool low_vram; | |||||
| public bool low_vram | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_low_vram); | |||||
| set => _low_vram = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _low_vram; | |||||
| /// <summary> | /// <summary> | ||||
| /// if true, use experimental mul_mat_q kernels | /// if true, use experimental mul_mat_q kernels | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] public bool mul_mat_q; | |||||
| public bool mul_mat_q | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_mul_mat_q); | |||||
| set => _mul_mat_q = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _mul_mat_q; | |||||
| /// <summary> | /// <summary> | ||||
| /// use fp16 for KV cache | /// use fp16 for KV cache | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool f16_kv; | |||||
| public bool f16_kv | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_f16_kv); | |||||
| set => _f16_kv = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _f16_kv; | |||||
| /// <summary> | /// <summary> | ||||
| /// the llama_eval() call computes all logits, not just the last one | /// the llama_eval() call computes all logits, not just the last one | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool logits_all; | |||||
| public bool logits_all | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_logits_all); | |||||
| set => _logits_all = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _logits_all; | |||||
| /// <summary> | /// <summary> | ||||
| /// only load the vocabulary, no weights | /// only load the vocabulary, no weights | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool vocab_only; | |||||
| public bool vocab_only | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_vocab_only); | |||||
| set => _vocab_only = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _vocab_only; | |||||
| /// <summary> | /// <summary> | ||||
| /// use mmap if possible | /// use mmap if possible | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool use_mmap; | |||||
| public bool use_mmap | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_use_mmap); | |||||
| set => _use_mmap = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _use_mmap; | |||||
| /// <summary> | /// <summary> | ||||
| /// force system to keep model in RAM | /// force system to keep model in RAM | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool use_mlock; | |||||
| public bool use_mlock | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_use_mlock); | |||||
| set => _use_mlock = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _use_mlock; | |||||
| /// <summary> | /// <summary> | ||||
| /// embedding mode only | /// embedding mode only | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool embedding; | |||||
| public bool embedding | |||||
| { | |||||
| readonly get => Convert.ToBoolean(_embedding); | |||||
| set => _embedding = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _embedding; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,29 +1,114 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace LLama.Native | |||||
| namespace LLama.Native | |||||
| { | { | ||||
| /// <summary> | |||||
| /// Supported model file types | |||||
| /// </summary> | |||||
| public enum LLamaFtype | public enum LLamaFtype | ||||
| { | { | ||||
| /// <summary> | |||||
| /// All f32 | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 26GB</remarks> | |||||
| LLAMA_FTYPE_ALL_F32 = 0, | LLAMA_FTYPE_ALL_F32 = 0, | ||||
| LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 | |||||
| // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed | |||||
| // LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed | |||||
| LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors | |||||
| LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors | |||||
| /// <summary> | |||||
| /// Mostly f16 | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 13GB</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_F16 = 1, | |||||
| /// <summary> | |||||
| /// Mostly 8 bit | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 6.7GB, +0.0004ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q8_0 = 7, | |||||
| /// <summary> | |||||
| /// Mostly 4 bit | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 3.50GB, +0.2499 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q4_0 = 2, | |||||
| /// <summary> | |||||
| /// Mostly 4 bit | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 3.90GB, +0.1846 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q4_1 = 3, | |||||
| /// <summary> | |||||
| /// Mostly 4 bit, tok_embeddings.weight and output.weight are f16 | |||||
| /// </summary> | |||||
| LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, | |||||
| /// <summary> | |||||
| /// Mostly 5 bit | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 4.30GB @ 7B tokens, +0.0796 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q5_0 = 8, | |||||
| /// <summary> | |||||
| /// Mostly 5 bit | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 4.70GB, +0.0415 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q5_1 = 9, | |||||
| /// <summary> | |||||
| /// K-Quant 2 bit | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 2.67GB @ 7N parameters, +0.8698 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q2_K = 10, | |||||
| /// <summary> | |||||
| /// K-Quant 3 bit (Small) | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 2.75GB, +0.5505 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, | |||||
| /// <summary> | |||||
| /// K-Quant 3 bit (Medium) | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 3.06GB, +0.2437 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, | |||||
| /// <summary> | |||||
| /// K-Quant 3 bit (Large) | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 3.35GB, +0.1803 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, | |||||
| /// <summary> | |||||
| /// K-Quant 4 bit (Small) | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 3.56GB, +0.1149 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, | |||||
| /// <summary> | |||||
| /// K-Quant 4 bit (Medium) | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 3.80GB, +0.0535 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, | |||||
| /// <summary> | |||||
| /// K-Quant 5 bit (Small) | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 4.33GB, +0.0353 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, | |||||
| /// <summary> | |||||
| /// K-Quant 5 bit (Medium) | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 4.45GB, +0.0142 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, | |||||
| /// <summary> | |||||
| /// K-Quant 6 bit | |||||
| /// </summary> | |||||
| /// <remarks>Benchmark@7B: 5.15GB, +0.0044 ppl</remarks> | |||||
| LLAMA_FTYPE_MOSTLY_Q6_K = 18, | |||||
| /// <summary> | |||||
| /// File type was not specified | |||||
| /// </summary> | |||||
| LLAMA_FTYPE_GUESSED = 1024 | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,124 @@ | |||||
| using System; | |||||
| using System.Diagnostics; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native | |||||
| { | |||||
| /// <summary> | |||||
| /// grammar element type | |||||
| /// </summary> | |||||
| public enum LLamaGrammarElementType | |||||
| { | |||||
| /// <summary> | |||||
| /// end of rule definition | |||||
| /// </summary> | |||||
| END = 0, | |||||
| /// <summary> | |||||
| /// start of alternate definition for rule | |||||
| /// </summary> | |||||
| ALT = 1, | |||||
| /// <summary> | |||||
| /// non-terminal element: reference to rule | |||||
| /// </summary> | |||||
| RULE_REF = 2, | |||||
| /// <summary> | |||||
| /// terminal element: character (code point) | |||||
| /// </summary> | |||||
| CHAR = 3, | |||||
| /// <summary> | |||||
| /// inverse char(s) ([^a], [^a-b] [^abc]) | |||||
| /// </summary> | |||||
| CHAR_NOT = 4, | |||||
| /// <summary> | |||||
| /// modifies a preceding CHAR or CHAR_ALT to | |||||
| /// be an inclusive range ([a-z]) | |||||
| /// </summary> | |||||
| CHAR_RNG_UPPER = 5, | |||||
| /// <summary> | |||||
| /// modifies a preceding CHAR or | |||||
| /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) | |||||
| /// </summary> | |||||
| CHAR_ALT = 6, | |||||
| }; | |||||
| /// <summary> | |||||
| /// An element of a grammar | |||||
| /// </summary> | |||||
| [StructLayout(LayoutKind.Sequential)] | |||||
| [DebuggerDisplay("{Type} {Value}")] | |||||
| public readonly struct LLamaGrammarElement | |||||
| : IEquatable<LLamaGrammarElement> | |||||
| { | |||||
| /// <summary> | |||||
| /// The type of this element | |||||
| /// </summary> | |||||
| public readonly LLamaGrammarElementType Type; | |||||
| /// <summary> | |||||
| /// Unicode code point or rule ID | |||||
| /// </summary> | |||||
| public readonly uint Value; | |||||
| /// <summary> | |||||
| /// Construct a new LLamaGrammarElement | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <param name="value"></param> | |||||
| public LLamaGrammarElement(LLamaGrammarElementType type, uint value) | |||||
| { | |||||
| Type = type; | |||||
| Value = value; | |||||
| } | |||||
| /// <inheritdoc /> | |||||
| public bool Equals(LLamaGrammarElement other) | |||||
| { | |||||
| if (Type != other.Type) | |||||
| return false; | |||||
| // No need to compare values for the END rule | |||||
| if (Type == LLamaGrammarElementType.END) | |||||
| return true; | |||||
| return Value == other.Value; | |||||
| } | |||||
| /// <inheritdoc /> | |||||
| public override bool Equals(object? obj) | |||||
| { | |||||
| return obj is LLamaGrammarElement other && Equals(other); | |||||
| } | |||||
| /// <inheritdoc /> | |||||
| public override int GetHashCode() | |||||
| { | |||||
| unchecked | |||||
| { | |||||
| var hash = 2999; | |||||
| hash = hash * 7723 + (int)Type; | |||||
| hash = hash * 7723 + (int)Value; | |||||
| return hash; | |||||
| } | |||||
| } | |||||
| internal bool IsCharElement() | |||||
| { | |||||
| switch (Type) | |||||
| { | |||||
| case LLamaGrammarElementType.CHAR: | |||||
| case LLamaGrammarElementType.CHAR_NOT: | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| return true; | |||||
| default: | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,29 +1,40 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Quantizer parameters used in the native API | |||||
| /// </summary> | |||||
| public struct LLamaModelQuantizeParams | public struct LLamaModelQuantizeParams | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() | /// number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() | ||||
| /// </summary> | /// </summary> | ||||
| public int nthread; | public int nthread; | ||||
| /// <summary> | /// <summary> | ||||
| /// quantize to this llama_ftype | /// quantize to this llama_ftype | ||||
| /// </summary> | /// </summary> | ||||
| public LLamaFtype ftype; | public LLamaFtype ftype; | ||||
| /// <summary> | /// <summary> | ||||
| /// allow quantizing non-f32/f16 tensors | /// allow quantizing non-f32/f16 tensors | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool allow_requantize; | |||||
| public bool allow_requantize | |||||
| { | |||||
| get => Convert.ToBoolean(_allow_requantize); | |||||
| set => _allow_requantize = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _allow_requantize; | |||||
| /// <summary> | /// <summary> | ||||
| /// quantize output.weight | /// quantize output.weight | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool quantize_output_tensor; | |||||
| public bool quantize_output_tensor | |||||
| { | |||||
| get => Convert.ToBoolean(_quantize_output_tensor); | |||||
| set => _quantize_output_tensor = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _quantize_output_tensor; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,4 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| @@ -2,6 +2,8 @@ | |||||
| using System.Buffers; | using System.Buffers; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using llama_token = System.Int32; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| @@ -15,9 +17,9 @@ namespace LLama.Native | |||||
| public readonly Memory<LLamaTokenData> data; | public readonly Memory<LLamaTokenData> data; | ||||
| /// <summary> | /// <summary> | ||||
| /// Indicates if `data` is sorted | |||||
| /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_. | |||||
| /// </summary> | /// </summary> | ||||
| public readonly bool sorted; | |||||
| public bool sorted; | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a new LLamaTokenDataArray | /// Create a new LLamaTokenDataArray | ||||
| @@ -29,6 +31,20 @@ namespace LLama.Native | |||||
| data = tokens; | data = tokens; | ||||
| sorted = isSorted; | sorted = isSorted; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Create a new LLamaTokenDataArray, copying the data from the given logits | |||||
| /// </summary> | |||||
| /// <param name="logits"></param> | |||||
| /// <returns></returns> | |||||
| public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits) | |||||
| { | |||||
| var candidates = new LLamaTokenData[logits.Length]; | |||||
| for (var token_id = 0; token_id < logits.Length; token_id++) | |||||
| candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); | |||||
| return new LLamaTokenDataArray(candidates); | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -51,8 +67,12 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// Indicates if the items in the array are sorted | /// Indicates if the items in the array are sorted | ||||
| /// </summary> | /// </summary> | ||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool sorted; | |||||
| public bool sorted | |||||
| { | |||||
| get => Convert.ToBoolean(_sorted); | |||||
| set => _sorted = Convert.ToSByte(value); | |||||
| } | |||||
| private sbyte _sorted; | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray | /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray | ||||
| @@ -0,0 +1,45 @@ | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native | |||||
| { | |||||
| using llama_token = Int32; | |||||
| public unsafe partial class NativeApi | |||||
| { | |||||
| /// <summary> | |||||
| /// Create a new grammar from the given set of grammar rules | |||||
| /// </summary> | |||||
| /// <param name="rules"></param> | |||||
| /// <param name="n_rules"></param> | |||||
| /// <param name="start_rule_index"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index); | |||||
| /// <summary> | |||||
| /// Free all memory from the given SafeLLamaGrammarHandle | |||||
| /// </summary> | |||||
| /// <param name="grammar"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_grammar_free(IntPtr grammar); | |||||
| /// <summary> | |||||
| /// Apply constraints from grammar | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates"></param> | |||||
| /// <param name="grammar"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_sample_grammar(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaGrammarHandle grammar); | |||||
| /// <summary> | |||||
| /// Accepts the sampled token into the grammar | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="grammar"></param> | |||||
| /// <param name="token"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, llama_token token); | |||||
| } | |||||
| } | |||||
| @@ -1,7 +1,4 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| @@ -16,6 +13,6 @@ namespace LLama.Native | |||||
| /// <remarks>not great API - very likely to change</remarks> | /// <remarks>not great API - very likely to change</remarks> | ||||
| /// <returns>Returns 0 on success</returns> | /// <returns>Returns 0 on success</returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public unsafe static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param); | |||||
| public static extern unsafe int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param); | |||||
| } | } | ||||
| } | } | ||||
| @@ -7,6 +7,16 @@ namespace LLama.Native | |||||
| public unsafe partial class NativeApi | public unsafe partial class NativeApi | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates">A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.</param> | |||||
| /// <param name="guidanceCtx">A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param> | |||||
| /// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_sample_classifier_free_guidance(SafeLLamaContextHandle ctx, LLamaTokenDataArrayNative candidates, SafeLLamaContextHandle guidanceCtx, float scale); | |||||
| /// <summary> | /// <summary> | ||||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -16,7 +26,7 @@ namespace LLama.Native | |||||
| /// <param name="last_tokens_size"></param> | /// <param name="last_tokens_size"></param> | ||||
| /// <param name="penalty"></param> | /// <param name="penalty"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty); | |||||
| public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty); | |||||
| /// <summary> | /// <summary> | ||||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | ||||
| @@ -28,7 +38,17 @@ namespace LLama.Native | |||||
| /// <param name="alpha_frequency"></param> | /// <param name="alpha_frequency"></param> | ||||
| /// <param name="alpha_presence"></param> | /// <param name="alpha_presence"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); | |||||
| public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); | |||||
| /// <summary> | |||||
| /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates">A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.</param> | |||||
| /// <param name="guidance_ctx">A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param> | |||||
| /// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_sample_classifier_free_guidance(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaContextHandle guidance_ctx, float scale); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. | /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. | ||||
| @@ -98,7 +118,7 @@ namespace LLama.Native | |||||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, float* mu); | |||||
| public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu); | |||||
| /// <summary> | /// <summary> | ||||
| /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | ||||
| @@ -110,7 +130,7 @@ namespace LLama.Native | |||||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, float* mu); | |||||
| public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu); | |||||
| /// <summary> | /// <summary> | ||||
| /// Selects the token with the highest probability. | /// Selects the token with the highest probability. | ||||
| @@ -1,14 +1,28 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using LLama.Common; | |||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| #pragma warning disable IDE1006 // Naming Styles | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| public unsafe partial class NativeApi | |||||
| /// <summary> | |||||
| /// Callback from llama.cpp with log messages | |||||
| /// </summary> | |||||
| /// <param name="level"></param> | |||||
| /// <param name="message"></param> | |||||
| public delegate void LLamaLogCallback(ILLamaLogger.LogLevel level, string message); | |||||
| /// <summary> | |||||
| /// Direct translation of the llama.cpp API | |||||
| /// </summary> | |||||
| public unsafe partial class NativeApi | |||||
| { | { | ||||
| public static readonly int LLAMA_MAX_DEVICES = 1; | |||||
| static NativeApi() | static NativeApi() | ||||
| { | { | ||||
| try | try | ||||
| @@ -28,21 +42,50 @@ namespace LLama.Native | |||||
| } | } | ||||
| private const string libraryName = "libllama"; | private const string libraryName = "libllama"; | ||||
| /// <summary> | |||||
| /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern bool llama_empty_call(); | public static extern bool llama_empty_call(); | ||||
| /// <summary> | |||||
| /// Create a LLamaContextParams with default values | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern LLamaContextParams llama_context_default_params(); | public static extern LLamaContextParams llama_context_default_params(); | ||||
| /// <summary> | |||||
| /// Create a LLamaModelQuantizeParams with default values | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern LLamaModelQuantizeParams llama_model_quantize_default_params(); | public static extern LLamaModelQuantizeParams llama_model_quantize_default_params(); | ||||
| /// <summary> | |||||
| /// Check if memory mapping is supported | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern bool llama_mmap_supported(); | public static extern bool llama_mmap_supported(); | ||||
| /// <summary> | |||||
| /// Check if memory lockingis supported | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern bool llama_mlock_supported(); | public static extern bool llama_mlock_supported(); | ||||
| /// <summary> | |||||
| /// Export a static computation graph for context of 511 and batch size of 1 | |||||
| /// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these | |||||
| /// parameters here to keep things simple | |||||
| /// IMPORTANT: do not use for anything else other than debugging and testing! | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="fname"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname); | public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname); | ||||
| @@ -52,13 +95,20 @@ namespace LLama.Native | |||||
| /// Return NULL on failure | /// Return NULL on failure | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="path_model"></param> | /// <param name="path_model"></param> | ||||
| /// <param name="params_"></param> | |||||
| /// <param name="params"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_); | |||||
| public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params); | |||||
| /// <summary> | |||||
| /// Create a new llama_context with the given model. | |||||
| /// Return value should always be wrapped in SafeLLamaContextHandle! | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <param name="params"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_); | |||||
| public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); | |||||
| /// <summary> | /// <summary> | ||||
| /// not great API - very likely to change. | /// not great API - very likely to change. | ||||
| @@ -69,7 +119,7 @@ namespace LLama.Native | |||||
| public static extern void llama_backend_init(bool numa); | public static extern void llama_backend_init(bool numa); | ||||
| /// <summary> | /// <summary> | ||||
| /// Frees all allocated memory | |||||
| /// Frees all allocated memory in the given llama_context | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| @@ -223,9 +273,6 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert the provided text into tokens. | /// Convert the provided text into tokens. | ||||
| /// The tokens pointer must be large enough to hold the resulting tokens. | |||||
| /// Returns the number of tokens on success, no more than n_max_tokens | |||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <param name="text"></param> | /// <param name="text"></param> | ||||
| @@ -233,35 +280,72 @@ namespace LLama.Native | |||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| /// <param name="n_max_tokens"></param> | /// <param name="n_max_tokens"></param> | ||||
| /// <param name="add_bos"></param> | /// <param name="add_bos"></param> | ||||
| /// <returns></returns> | |||||
| /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | |||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | |||||
| /// </returns> | |||||
| public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) | public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) | ||||
| { | { | ||||
| var bytes = encoding.GetBytes(text); | |||||
| sbyte[] data = new sbyte[bytes.Length]; | |||||
| for(int i = 0; i < bytes.Length; i++) | |||||
| // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) | |||||
| var byteCount = encoding.GetByteCount(text); | |||||
| var array = ArrayPool<byte>.Shared.Rent(byteCount + 1); | |||||
| try | |||||
| { | |||||
| // Convert to bytes | |||||
| fixed (char* textPtr = text) | |||||
| fixed (byte* arrayPtr = array) | |||||
| { | |||||
| encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length); | |||||
| } | |||||
| // Add a zero byte to the end to terminate the string | |||||
| array[byteCount] = 0; | |||||
| // Do the actual tokenization | |||||
| fixed (byte* arrayPtr = array) | |||||
| fixed (llama_token* tokensPtr = tokens) | |||||
| return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos); | |||||
| } | |||||
| finally | |||||
| { | { | ||||
| data[i] = (sbyte)bytes[i]; | |||||
| //if (bytes[i] < 128) | |||||
| //{ | |||||
| // data[i] = (sbyte)bytes[i]; | |||||
| //} | |||||
| //else | |||||
| //{ | |||||
| // data[i] = (sbyte)(~((sbyte)(~bytes[i] + 1)) + 1); | |||||
| //} | |||||
| ArrayPool<byte>.Shared.Return(array); | |||||
| } | } | ||||
| return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos); | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Convert the provided text into tokens. | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="text"></param> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="n_max_tokens"></param> | |||||
| /// <param name="add_bos"></param> | |||||
| /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | |||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | |||||
| /// </returns> | |||||
| [DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos); | |||||
| public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos); | |||||
| /// <summary> | |||||
| /// Get the number of tokens in the model vocabulary for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); | public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); | ||||
| /// <summary> | |||||
| /// Get the size of the context window for the model for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); | public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); | ||||
| /// <summary> | |||||
| /// Get the dimension of embedding vectors from the model for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_n_embd(SafeLLamaContextHandle ctx); | public static extern int llama_n_embd(SafeLLamaContextHandle ctx); | ||||
| @@ -295,18 +379,38 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token); | public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token); | ||||
| /// <summary> | |||||
| /// Get the "Beginning of sentence" token | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_token_bos(); | |||||
| public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx); | |||||
| /// <summary> | |||||
| /// Get the "End of sentence" token | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_token_eos(); | |||||
| public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx); | |||||
| /// <summary> | |||||
| /// Get the "new line" token | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_token_nl(); | |||||
| public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx); | |||||
| /// <summary> | |||||
| /// Print out timing information for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_print_timings(SafeLLamaContextHandle ctx); | public static extern void llama_print_timings(SafeLLamaContextHandle ctx); | ||||
| /// <summary> | |||||
| /// Reset all collected timing information for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_reset_timings(SafeLLamaContextHandle ctx); | public static extern void llama_reset_timings(SafeLLamaContextHandle ctx); | ||||
| @@ -317,19 +421,60 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern IntPtr llama_print_system_info(); | public static extern IntPtr llama_print_system_info(); | ||||
| /// <summary> | |||||
| /// Get the number of tokens in the model vocabulary | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model); | |||||
| public static extern int llama_model_n_vocab(SafeLlamaModelHandle model); | |||||
| /// <summary> | |||||
| /// Get the size of the context window for the model | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model); | |||||
| public static extern int llama_model_n_ctx(SafeLlamaModelHandle model); | |||||
| /// <summary> | |||||
| /// Get the dimension of embedding vectors from this model | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model); | |||||
| public static extern int llama_model_n_embd(SafeLlamaModelHandle model); | |||||
| /// <summary> | |||||
| /// Convert a single token into text | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <param name="llamaToken"></param> | |||||
| /// <param name="buffer">buffer to write string into</param> | |||||
| /// <param name="length">size of the buffer</param> | |||||
| /// <returns>The length writte, or if the buffer is too small a negative that indicates the length required</returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken); | |||||
| public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); | |||||
| /// <summary> | |||||
| /// Convert text into tokens | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <param name="text"></param> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="n_max_tokens"></param> | |||||
| /// <param name="add_bos"></param> | |||||
| /// <returns>Returns the number of tokens on success, no more than n_max_tokens. | |||||
| /// Returns a negative number on failure - the number of tokens that would have been returned | |||||
| /// </returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos); | public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos); | ||||
| } | |||||
| /// <summary> | |||||
| /// Register a callback to receive llama log messages | |||||
| /// </summary> | |||||
| /// <param name="logCallback"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_log_set(LLamaLogCallback logCallback); | |||||
| } | |||||
| } | } | ||||
| @@ -1,15 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace LLama.Native | |||||
| { | |||||
| internal class NativeInfo | |||||
| { | |||||
| internal static readonly int LLAMA_FILE_VERSION = 1; | |||||
| internal static readonly string LLAMA_FILE_MAGIC = "ggjt"; | |||||
| internal static readonly string LLAMA_FILE_MAGIC_UNVERSIONED = "ggml"; | |||||
| internal static readonly string LLAMA_SESSION_MAGIC = "ggsn"; | |||||
| internal static readonly int LLAMA_SESSION_VERSION = 1; | |||||
| } | |||||
| } | |||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | using System.Buffers; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| @@ -8,7 +9,7 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// A safe wrapper around a llama_context | /// A safe wrapper around a llama_context | ||||
| /// </summary> | /// </summary> | ||||
| public class SafeLLamaContextHandle | |||||
| public sealed class SafeLLamaContextHandle | |||||
| : SafeLLamaHandleBase | : SafeLLamaHandleBase | ||||
| { | { | ||||
| #region properties and fields | #region properties and fields | ||||
| @@ -25,11 +26,13 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// Dimension of embedding vectors | /// Dimension of embedding vectors | ||||
| /// </summary> | /// </summary> | ||||
| public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount; | |||||
| public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize; | |||||
| /// <summary> | /// <summary> | ||||
| /// This field guarantees that a reference to the model is held for as long as this handle is held | |||||
| /// Get the model which this context is using | |||||
| /// </summary> | /// </summary> | ||||
| public SafeLlamaModelHandle ModelHandle => ThrowIfDisposed(); | |||||
| private SafeLlamaModelHandle? _model; | private SafeLlamaModelHandle? _model; | ||||
| #endregion | #endregion | ||||
| @@ -55,7 +58,7 @@ namespace LLama.Native | |||||
| { | { | ||||
| // Decrement refcount on model | // Decrement refcount on model | ||||
| _model?.DangerousRelease(); | _model?.DangerousRelease(); | ||||
| _model = null; | |||||
| _model = null!; | |||||
| NativeApi.llama_free(handle); | NativeApi.llama_free(handle); | ||||
| SetHandle(IntPtr.Zero); | SetHandle(IntPtr.Zero); | ||||
| @@ -69,7 +72,7 @@ namespace LLama.Native | |||||
| if (_model == null || _model.IsClosed) | if (_model == null || _model.IsClosed) | ||||
| throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed"); | throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed"); | ||||
| return _model; | |||||
| return _model!; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -87,6 +90,35 @@ namespace LLama.Native | |||||
| return new(ctx_ptr, model); | return new(ctx_ptr, model); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Create a new llama context with a clone of the current llama context state | |||||
| /// </summary> | |||||
| /// <param name="lparams"></param> | |||||
| /// <returns></returns> | |||||
| public SafeLLamaContextHandle Clone(LLamaContextParams lparams) | |||||
| { | |||||
| // Allocate space to read the state of the current context | |||||
| var stateSize = GetStateSize(); | |||||
| var stateMemory = Marshal.AllocHGlobal((nint)stateSize); | |||||
| try | |||||
| { | |||||
| // Copy state from this context into memory | |||||
| GetState(stateMemory, stateSize); | |||||
| // Create a new context | |||||
| var newCtx = Create(ModelHandle, lparams); | |||||
| // Copy state into new context | |||||
| newCtx.SetState(stateMemory); | |||||
| return newCtx; | |||||
| } | |||||
| finally | |||||
| { | |||||
| Marshal.FreeHGlobal(stateMemory); | |||||
| } | |||||
| } | |||||
| #endregion | #endregion | ||||
| /// <summary> | /// <summary> | ||||
| @@ -136,7 +168,6 @@ namespace LLama.Native | |||||
| /// Rows: n_tokens<br /> | /// Rows: n_tokens<br /> | ||||
| /// Cols: n_vocab | /// Cols: n_vocab | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Span<float> GetLogits() | public Span<float> GetLogits() | ||||
| { | { | ||||
| @@ -152,7 +183,7 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// Convert a token into a string | /// Convert a token into a string | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="token"></param> | |||||
| /// <param name="token">Token to decode into a string</param> | |||||
| /// <param name="encoding"></param> | /// <param name="encoding"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public string TokenToString(int token, Encoding encoding) | public string TokenToString(int token, Encoding encoding) | ||||
| @@ -161,13 +192,25 @@ namespace LLama.Native | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Convert a token into a span of bytes that could be decoded into a string | |||||
| /// Append a single llama token to a string builder | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="token"></param> | |||||
| /// <returns></returns> | |||||
| public ReadOnlySpan<byte> TokenToSpan(int token) | |||||
| /// <param name="token">Token to decode</param> | |||||
| /// <param name="encoding"></param> | |||||
| /// <param name="dest">string builder to append the result to</param> | |||||
| public void TokenToString(int token, Encoding encoding, StringBuilder dest) | |||||
| { | { | ||||
| return ThrowIfDisposed().TokenToSpan(token); | |||||
| ThrowIfDisposed().TokenToString(token, encoding, dest); | |||||
| } | |||||
| /// <summary> | |||||
| /// Convert a single llama token into bytes | |||||
| /// </summary> | |||||
| /// <param name="token">Token to decode</param> | |||||
| /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | |||||
| /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | |||||
| public int TokenToSpan(int token, Span<byte> dest) | |||||
| { | |||||
| return ThrowIfDisposed().TokenToSpan(token, dest); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -177,13 +220,79 @@ namespace LLama.Native | |||||
| /// <param name="n_past">the number of tokens to use from previous eval calls</param> | /// <param name="n_past">the number of tokens to use from previous eval calls</param> | ||||
| /// <param name="n_threads"></param> | /// <param name="n_threads"></param> | ||||
| /// <returns>Returns true on success</returns> | /// <returns>Returns true on success</returns> | ||||
| public bool Eval(Memory<int> tokens, int n_past, int n_threads) | |||||
| public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads) | |||||
| { | { | ||||
| using var pin = tokens.Pin(); | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0; | |||||
| fixed (int* pinned = tokens) | |||||
| { | |||||
| return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| #region state | |||||
| /// <summary> | |||||
| /// Get the size of the state, when saved as bytes | |||||
| /// </summary> | |||||
| public ulong GetStateSize() | |||||
| { | |||||
| return NativeApi.llama_get_state_size(this); | |||||
| } | |||||
| /// <summary> | |||||
| /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer. | |||||
| /// </summary> | |||||
| /// <param name="dest">Destination to write to</param> | |||||
| /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param> | |||||
| /// <returns>The number of bytes written to dest</returns> | |||||
| /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception> | |||||
| public unsafe ulong GetState(byte* dest, ulong size) | |||||
| { | |||||
| return GetState(new IntPtr(dest), size); | |||||
| } | |||||
| /// <summary> | |||||
| /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer. | |||||
| /// </summary> | |||||
| /// <param name="dest">Destination to write to</param> | |||||
| /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param> | |||||
| /// <returns>The number of bytes written to dest</returns> | |||||
| /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception> | |||||
| public ulong GetState(IntPtr dest, ulong size) | |||||
| { | |||||
| var required = GetStateSize(); | |||||
| if (size < required) | |||||
| throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}"); | |||||
| unsafe | |||||
| { | |||||
| return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer()); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Set the raw state of this context | |||||
| /// </summary> | |||||
| /// <param name="src">The pointer to read the state from</param> | |||||
| /// <returns>Number of bytes read from the src pointer</returns> | |||||
| public unsafe ulong SetState(byte* src) | |||||
| { | |||||
| return SetState(new IntPtr(src)); | |||||
| } | |||||
| /// <summary> | |||||
| /// Set the raw state of this context | |||||
| /// </summary> | |||||
| /// <param name="src">The pointer to read the state from</param> | |||||
| /// <returns>Number of bytes read from the src pointer</returns> | |||||
| public ulong SetState(IntPtr src) | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer()); | |||||
| } | |||||
| } | |||||
| #endregion | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,106 @@ | |||||
| using System; | |||||
| using System.Buffers; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using LLama.Exceptions; | |||||
| using LLama.Grammars; | |||||
| namespace LLama.Native | |||||
| { | |||||
| /// <summary> | |||||
| /// A safe reference to a `llama_grammar` | |||||
| /// </summary> | |||||
| public class SafeLLamaGrammarHandle | |||||
| : SafeLLamaHandleBase | |||||
| { | |||||
| #region construction/destruction | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="handle"></param> | |||||
| internal SafeLLamaGrammarHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| /// <inheritdoc /> | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| NativeApi.llama_grammar_free(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| /// <summary> | |||||
| /// Create a new llama_grammar | |||||
| /// </summary> | |||||
| /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param> | |||||
| /// <param name="start_rule_index">The index (in the outer list) of the start rule</param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index) | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| var totalElements = rules.Sum(a => a.Elements.Count); | |||||
| var nrules = (ulong)rules.Count; | |||||
| // Borrow an array large enough to hold every single element | |||||
| // and another array large enough to hold a pointer to each rule | |||||
| var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements); | |||||
| var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count); | |||||
| try | |||||
| { | |||||
| fixed (LLamaGrammarElement* allElementsPtr = allElements) | |||||
| { | |||||
| var elementIndex = 0; | |||||
| var pointerIndex = 0; | |||||
| foreach (var rule in rules) | |||||
| { | |||||
| // Save a pointer to the start of this rule | |||||
| pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex); | |||||
| // Copy all of the rule elements into the flat array | |||||
| foreach (var element in rule.Elements) | |||||
| allElementsPtr[elementIndex++] = element; | |||||
| } | |||||
| // Sanity check some things that should be true if the copy worked as planned | |||||
| Debug.Assert((ulong)pointerIndex == nrules); | |||||
| Debug.Assert(elementIndex == totalElements); | |||||
| // Make the actual call through to llama.cpp | |||||
| fixed (void* ptr = pointers) | |||||
| { | |||||
| return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index); | |||||
| } | |||||
| } | |||||
| } | |||||
| finally | |||||
| { | |||||
| ArrayPool<LLamaGrammarElement>.Shared.Return(allElements); | |||||
| ArrayPool<IntPtr>.Shared.Return(pointers); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Create a new llama_grammar | |||||
| /// </summary> | |||||
| /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param> | |||||
| /// <param name="nrules">total number of rules</param> | |||||
| /// <param name="start_rule_index">index of the start rule of the grammar</param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="RuntimeError"></exception> | |||||
| public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index) | |||||
| { | |||||
| var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index); | |||||
| if (grammar_ptr == IntPtr.Zero) | |||||
| throw new RuntimeError("Failed to create grammar from rules"); | |||||
| return new(grammar_ptr); | |||||
| } | |||||
| #endregion | |||||
| } | |||||
| } | |||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Diagnostics; | |||||
| using System.Text; | using System.Text; | ||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| @@ -7,7 +8,7 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// A reference to a set of llama model weights | /// A reference to a set of llama model weights | ||||
| /// </summary> | /// </summary> | ||||
| public class SafeLlamaModelHandle | |||||
| public sealed class SafeLlamaModelHandle | |||||
| : SafeLLamaHandleBase | : SafeLLamaHandleBase | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| @@ -23,14 +24,14 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// Dimension of embedding vectors | /// Dimension of embedding vectors | ||||
| /// </summary> | /// </summary> | ||||
| public int EmbeddingCount { get; } | |||||
| public int EmbeddingSize { get; } | |||||
| internal SafeLlamaModelHandle(IntPtr handle) | internal SafeLlamaModelHandle(IntPtr handle) | ||||
| : base(handle) | : base(handle) | ||||
| { | { | ||||
| VocabCount = NativeApi.llama_n_vocab_from_model(this); | |||||
| ContextSize = NativeApi.llama_n_ctx_from_model(this); | |||||
| EmbeddingCount = NativeApi.llama_n_embd_from_model(this); | |||||
| VocabCount = NativeApi.llama_model_n_vocab(this); | |||||
| ContextSize = NativeApi.llama_model_n_ctx(this); | |||||
| EmbeddingSize = NativeApi.llama_model_n_embd(this); | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -82,17 +83,20 @@ namespace LLama.Native | |||||
| #region tokenize | #region tokenize | ||||
| /// <summary> | /// <summary> | ||||
| /// Convert a single llama token into string bytes | |||||
| /// Convert a single llama token into bytes | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="llama_token"></param> | |||||
| /// <returns></returns> | |||||
| public ReadOnlySpan<byte> TokenToSpan(int llama_token) | |||||
| /// <param name="llama_token">Token to decode</param> | |||||
| /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | |||||
| /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | |||||
| public int TokenToSpan(int llama_token, Span<byte> dest) | |||||
| { | { | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue); | |||||
| var terminator = bytes.IndexOf((byte)0); | |||||
| return bytes.Slice(0, terminator); | |||||
| fixed (byte* destPtr = dest) | |||||
| { | |||||
| var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length); | |||||
| return Math.Abs(length); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -104,16 +108,54 @@ namespace LLama.Native | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public string TokenToString(int llama_token, Encoding encoding) | public string TokenToString(int llama_token, Encoding encoding) | ||||
| { | { | ||||
| var span = TokenToSpan(llama_token); | |||||
| unsafe | |||||
| { | |||||
| var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); | |||||
| if (length == 0) | |||||
| return ""; | |||||
| if (span.Length == 0) | |||||
| return ""; | |||||
| Span<byte> bytes = stackalloc byte[-length]; | |||||
| fixed (byte* bytePtr = bytes) | |||||
| { | |||||
| var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); | |||||
| Debug.Assert(written == bytes.Length); | |||||
| return encoding.GetString(bytePtr, bytes.Length); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Append a single llama token to a string builder | |||||
| /// </summary> | |||||
| /// <param name="llama_token">Token to decode</param> | |||||
| /// <param name="encoding"></param> | |||||
| /// <param name="dest">string builder to append the result to</param> | |||||
| public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest) | |||||
| { | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| fixed (byte* ptr = &span[0]) | |||||
| var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); | |||||
| if (length == 0) | |||||
| return; | |||||
| Span<byte> bytes = stackalloc byte[-length]; | |||||
| fixed (byte* bytePtr = bytes) | |||||
| { | { | ||||
| return encoding.GetString(ptr, span.Length); | |||||
| // Decode into bytes | |||||
| var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); | |||||
| Debug.Assert(written == bytes.Length); | |||||
| // Decode into chars | |||||
| var charCount = encoding.GetCharCount(bytePtr, bytes.Length); | |||||
| Span<char> chars = stackalloc char[charCount]; | |||||
| fixed (char* charPtr = chars) | |||||
| encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length); | |||||
| // Write it to the output | |||||
| for (var i = 0; i < chars.Length; i++) | |||||
| dest.Append(chars[i]); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -150,12 +192,24 @@ namespace LLama.Native | |||||
| var tokens = new int[count]; | var tokens = new int[count]; | ||||
| fixed (int* tokensPtr = &tokens[0]) | fixed (int* tokensPtr = &tokens[0]) | ||||
| { | { | ||||
| count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos); | |||||
| NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos); | |||||
| return tokens; | return tokens; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| #endregion | #endregion | ||||
| #region context | |||||
| /// <summary> | |||||
| /// Create a new context for this model | |||||
| /// </summary> | |||||
| /// <param name="params"></param> | |||||
| /// <returns></returns> | |||||
| public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) | |||||
| { | |||||
| return SafeLLamaContextHandle.Create(this, @params); | |||||
| } | |||||
| #endregion | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,28 @@ | |||||
| using System; | using System; | ||||
| #pragma warning disable IDE1006 // Naming Styles | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| /// <summary> | |||||
| /// Direct translation of the llama.cpp sampling API | |||||
| /// </summary> | |||||
| public unsafe class SamplingApi | public unsafe class SamplingApi | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Apply grammar rules to candidate tokens | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates"></param> | |||||
| /// <param name="grammar"></param> | |||||
| public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar) | |||||
| { | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_grammar(ctx, ref st, grammar); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -13,10 +31,25 @@ namespace LLama.Native | |||||
| /// <param name="last_tokens"></param> | /// <param name="last_tokens"></param> | ||||
| /// <param name="last_tokens_size"></param> | /// <param name="last_tokens_size"></param> | ||||
| /// <param name="penalty"></param> | /// <param name="penalty"></param> | ||||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty) | |||||
| [Obsolete("last_tokens_size parameter is no longer needed")] | |||||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty) | |||||
| { | |||||
| llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty); | |||||
| } | |||||
| /// <summary> | |||||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||||
| /// <param name="last_tokens"></param> | |||||
| /// <param name="penalty"></param> | |||||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty) | |||||
| { | { | ||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | ||||
| NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty); | |||||
| using var last_tokens_handle = last_tokens.Pin(); | |||||
| NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -28,10 +61,26 @@ namespace LLama.Native | |||||
| /// <param name="last_tokens_size"></param> | /// <param name="last_tokens_size"></param> | ||||
| /// <param name="alpha_frequency"></param> | /// <param name="alpha_frequency"></param> | ||||
| /// <param name="alpha_presence"></param> | /// <param name="alpha_presence"></param> | ||||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) | |||||
| [Obsolete("last_tokens_size parameter is no longer needed")] | |||||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) | |||||
| { | |||||
| llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence); | |||||
| } | |||||
| /// <summary> | |||||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||||
| /// <param name="last_tokens"></param> | |||||
| /// <param name="alpha_frequency"></param> | |||||
| /// <param name="alpha_presence"></param> | |||||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence) | |||||
| { | { | ||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | ||||
| NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence); | |||||
| using var last_tokens_handle = last_tokens.Pin(); | |||||
| NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -97,6 +146,13 @@ namespace LLama.Native | |||||
| NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); | NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Sample with temperature. | |||||
| /// As temperature increases, the prediction becomes diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates"></param> | |||||
| /// <param name="temp"></param> | |||||
| public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) | public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) | ||||
| { | { | ||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | ||||
| @@ -116,10 +172,7 @@ namespace LLama.Native | |||||
| public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | ||||
| { | { | ||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | ||||
| fixed(float* pmu = &mu) | |||||
| { | |||||
| return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu); | |||||
| } | |||||
| return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -134,10 +187,7 @@ namespace LLama.Native | |||||
| public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | ||||
| { | { | ||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | ||||
| fixed (float* pmu = &mu) | |||||
| { | |||||
| return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu); | |||||
| } | |||||
| return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,10 +1,13 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Text; | |||||
| #pragma warning disable | |||||
| // ReSharper disable all | |||||
| namespace LLama.OldVersion | namespace LLama.OldVersion | ||||
| { | { | ||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class ChatSession<T> where T : IChatModel | public class ChatSession<T> where T : IChatModel | ||||
| { | { | ||||
| IChatModel _model; | IChatModel _model; | ||||
| @@ -1,9 +1,12 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | |||||
| #pragma warning disable | |||||
| // ReSharper disable all | |||||
| namespace LLama.OldVersion | namespace LLama.OldVersion | ||||
| { | { | ||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public interface IChatModel | public interface IChatModel | ||||
| { | { | ||||
| string Name { get; } | string Name { get; } | ||||
| @@ -1,12 +1,15 @@ | |||||
| using LLama.Native; | using LLama.Native; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| #pragma warning disable | |||||
| // ReSharper disable all | |||||
| namespace LLama.OldVersion | namespace LLama.OldVersion | ||||
| { | { | ||||
| public class LLamaEmbedder : IDisposable | |||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class LLamaEmbedder | |||||
| : IDisposable | |||||
| { | { | ||||
| SafeLLamaContextHandle _ctx; | SafeLLamaContextHandle _ctx; | ||||
| @@ -9,10 +9,16 @@ using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using LLama.Common; | using LLama.Common; | ||||
| #pragma warning disable | |||||
| // ReSharper disable all | |||||
| namespace LLama.OldVersion | namespace LLama.OldVersion | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| public class LLamaModel : IChatModel, IDisposable | |||||
| [Obsolete("The entire LLama.OldVersion namespace will be removed")] | |||||
| public class LLamaModel | |||||
| : IChatModel, IDisposable | |||||
| { | { | ||||
| LLamaParams _params; | LLamaParams _params; | ||||
| SafeLLamaContextHandle _ctx; | SafeLLamaContextHandle _ctx; | ||||
| @@ -27,7 +33,6 @@ namespace LLama.OldVersion | |||||
| bool _is_interacting; | bool _is_interacting; | ||||
| bool _is_antiprompt; | bool _is_antiprompt; | ||||
| bool _input_echo; | bool _input_echo; | ||||
| bool _verbose; | |||||
| // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session | // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session | ||||
| // if we loaded a session with at least 75% similarity. It's currently just used to speed up the | // if we loaded a session with at least 75% similarity. It's currently just used to speed up the | ||||
| @@ -40,17 +45,8 @@ namespace LLama.OldVersion | |||||
| List<llama_token> _embed; | List<llama_token> _embed; | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public bool Verbose | |||||
| { | |||||
| get | |||||
| { | |||||
| return _verbose; | |||||
| } | |||||
| set | |||||
| { | |||||
| _verbose = value; | |||||
| } | |||||
| } | |||||
| public bool Verbose { get; set; } | |||||
| public SafeLLamaContextHandle NativeHandle => _ctx; | public SafeLLamaContextHandle NativeHandle => _ctx; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -173,7 +169,7 @@ namespace LLama.OldVersion | |||||
| { | { | ||||
| Name = name; | Name = name; | ||||
| _params = @params; | _params = @params; | ||||
| _verbose = verbose; | |||||
| Verbose = verbose; | |||||
| _ctx = Utils.llama_init_from_gpt_params(ref _params); | _ctx = Utils.llama_init_from_gpt_params(ref _params); | ||||
| // Add a space in front of the first character to match OG llama tokenizer behavior | // Add a space in front of the first character to match OG llama tokenizer behavior | ||||
| @@ -509,7 +505,7 @@ namespace LLama.OldVersion | |||||
| } | } | ||||
| if (_is_interacting) | if (_is_interacting) | ||||
| { | { | ||||
| if (_verbose) | |||||
| if (Verbose) | |||||
| { | { | ||||
| LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it."); | LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it."); | ||||
| } | } | ||||
| @@ -620,7 +616,7 @@ namespace LLama.OldVersion | |||||
| NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count); | NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count); | ||||
| } | } | ||||
| llama_token id = 0; | |||||
| llama_token id; | |||||
| { | { | ||||
| var n_vocab = NativeApi.llama_n_vocab(_ctx); | var n_vocab = NativeApi.llama_n_vocab(_ctx); | ||||
| @@ -638,7 +634,7 @@ namespace LLama.OldVersion | |||||
| LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); | LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); | ||||
| // Apply penalties | // Apply penalties | ||||
| float nl_logit = logits[NativeApi.llama_token_nl()]; | |||||
| float nl_logit = logits[NativeApi.llama_token_nl(_ctx)]; | |||||
| var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx); | var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx); | ||||
| SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, | SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, | ||||
| _last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(), | _last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(), | ||||
| @@ -648,7 +644,7 @@ namespace LLama.OldVersion | |||||
| (ulong)last_n_repeat, alpha_frequency, alpha_presence); | (ulong)last_n_repeat, alpha_frequency, alpha_presence); | ||||
| if (!penalize_nl) | if (!penalize_nl) | ||||
| { | { | ||||
| logits[NativeApi.llama_token_nl()] = nl_logit; | |||||
| logits[NativeApi.llama_token_nl(_ctx)] = nl_logit; | |||||
| } | } | ||||
| if (temp <= 0) | if (temp <= 0) | ||||
| @@ -688,7 +684,7 @@ namespace LLama.OldVersion | |||||
| } | } | ||||
| // replace end of text token with newline token when in interactive mode | // replace end of text token with newline token when in interactive mode | ||||
| if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct) | |||||
| if (id == NativeApi.llama_token_eos(_ctx) && _params.interactive && !_params.instruct) | |||||
| { | { | ||||
| id = _llama_token_newline[0]; | id = _llama_token_newline[0]; | ||||
| if (_params.antiprompt.Count != 0) | if (_params.antiprompt.Count != 0) | ||||
| @@ -764,7 +760,7 @@ namespace LLama.OldVersion | |||||
| break; | break; | ||||
| } | } | ||||
| if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos()) | |||||
| if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos(_ctx)) | |||||
| { | { | ||||
| if (_params.instruct) | if (_params.instruct) | ||||
| { | { | ||||