@@ -140,7 +140,7 @@ jobs:
          - build: 'arm64'
            defines: '-DCMAKE_OSX_ARCHITECTURES=arm64'
          - build: 'x64'
            defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF'
            defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF -DLLAMA_AVX=ON -DLLAMA_AVX2=ON'
    runs-on: macos-latest
    steps:
      - uses: actions/checkout@v3
@@ -0,0 +1,24 @@
{
  "messages": [
    {
      "author_role": "System",
      "content": "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision."
    },
    {
      "author_role": "User",
      "content": "Hello, Bob."
    },
    {
      "author_role": "Assistant",
      "content": "Hello. How may I help you today?"
    },
    {
      "author_role": "User",
      "content": "Please tell me the largest city in Europe."
    },
    {
      "author_role": "Assistant",
      "content": "Sure. The largest city in Europe is Istanbul, Turkey."
    }
  ]
}
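// Editor's note: a minimal sketch (not part of this PR) of how the asset above is
// consumed. The "author_role"/"content" keys bind to ChatHistory.Message through the
// JsonPropertyName attributes added to ChatHistory later in this diff, and the csproj
// change below copies the file to the output directory. Assumes implicit usings and
// `using LLama.Common;` as in the examples further down.
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory history = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
Console.WriteLine(history.Messages.Count); // 5 messages in this sample asset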
@@ -0,0 +1,24 @@
{
  "messages": [
    {
      "author_role": "System",
      "content": "下面是一段你和用户的对话,你叫坤坤,是一个在各方面都拥有丰富经验的助理,你非常乐于回答用户的问题和帮助用户。"
    },
    {
      "author_role": "User",
      "content": "你好,坤坤。"
    },
    {
      "author_role": "Assistant",
      "content": "你好,有什么我能帮助你的吗?"
    },
    {
      "author_role": "User",
      "content": "中国的首都是哪座城市?"
    },
    {
      "author_role": "Assistant",
      "content": "中国的首都是北京市。"
    }
  ]
}
@@ -1,69 +1,122 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Text;
using LLama.Common;

namespace LLama.Examples.Examples
namespace LLama.Examples.Examples;

public class ChatChineseGB2312
{
    public class ChatChineseGB2312
    private static string ConvertEncoding(string input, Encoding original, Encoding target)
    {
        byte[] bytes = original.GetBytes(input);
        var convertedBytes = Encoding.Convert(original, target, bytes);
        return target.GetString(convertedBytes);
    }

    public static async Task Run()
    {
        private static string ConvertFromEncodingToAnother(string input, Encoding original, Encoding target)
        // Register provider for GB2312 encoding
        Encoding.RegisterProvider(CodePagesEncodingProvider.Instance);

        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common on Windows. It's recommended" +
            " to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers.");
        Console.ForegroundColor = ConsoleColor.White;

        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        var parameters = new ModelParams(modelPath)
        {
            ContextSize = 1024,
            Seed = 1337,
            GpuLayerCount = 5,
            Encoding = Encoding.UTF8
        };
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        ChatSession session;
        if (Directory.Exists("Assets/chat-with-kunkun-chinese"))
        {
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("Loading session from disk.");
            Console.ForegroundColor = ConsoleColor.White;

            session = new ChatSession(executor);
            session.LoadSession("Assets/chat-with-kunkun-chinese");
        }
        else
        {
            byte[] bytes = original.GetBytes(input);
            var convertedBytes = Encoding.Convert(original, target, bytes);
            return target.GetString(convertedBytes);
            var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json");
            ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

            session = new ChatSession(executor, chatHistory);
        }

        public static async Task Run()
        session
            .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤"));

        InferenceParams inferenceParams = new InferenceParams()
        {
            Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); // Register gb2312 encoding
            Console.Write("Please input your model path: ");
            var modelPath = Console.ReadLine();
            var prompt = File.ReadAllText("Assets/chat-with-kunkun-chinese.txt", encoding: Encoding.GetEncoding("gb2312")).Trim();
            prompt = ConvertFromEncodingToAnother(prompt, Encoding.GetEncoding("gb2312"), Encoding.UTF8);
            Temperature = 0.9f,
            AntiPrompts = new List<string> { "用户:" }
        };

            var parameters = new ModelParams(modelPath)
            {
                ContextSize = 1024,
                Seed = 1337,
                GpuLayerCount = 20,
                Encoding = Encoding.UTF8
            };
            using var model = LLamaWeights.LoadFromFile(parameters);
            using var context = model.CreateContext(parameters);
            var executor = new InteractiveExecutor(context);
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("The chat session has started.");
            var session = new ChatSession(executor).WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户"));
        // show the prompt
        Console.ForegroundColor = ConsoleColor.White;
        Console.Write("用户:");
        Console.ForegroundColor = ConsoleColor.Green;
        string userInput = Console.ReadLine() ?? "";
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common on Windows. It's recommended" +
                " to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers.");
            Console.ForegroundColor = ConsoleColor.White;

        while (userInput != "exit")
        {
            // Convert the encoding from gb2312 to utf8 for the language model
            // and later saving to the history json file.
            userInput = ConvertEncoding(userInput, Encoding.GetEncoding("gb2312"), Encoding.UTF8);

            // show the prompt
            Console.Write(prompt);
            while (true)
            if (userInput == "save")
            {
                await foreach (var text in session.ChatAsync(prompt, new InferenceParams()
                session.SaveSession("Assets/chat-with-kunkun-chinese");
                Console.ForegroundColor = ConsoleColor.Yellow;
                Console.WriteLine("Session saved.");
            }
            else if (userInput == "regenerate")
            {
                Console.ForegroundColor = ConsoleColor.Yellow;
                Console.WriteLine("Regenerating last response ...");

                await foreach (
                    var text
                    in session.RegenerateAssistantMessageAsync(
                        inferenceParams))
                {
                    Temperature = 0.3f,
                    TopK = 5,
                    TopP = 0.85f,
                    AntiPrompts = new List<string> { "用户:" },
                    MaxTokens = 2048,
                    RepeatPenalty = 1.05f
                }))
                    Console.ForegroundColor = ConsoleColor.White;

                    // Convert the encoding from utf8 to gb2312 for the console output.
                    Console.Write(ConvertEncoding(text, Encoding.UTF8, Encoding.GetEncoding("gb2312")));
                }
            }
            else
            {
                await foreach (
                    var text
                    in session.ChatAsync(
                        new ChatHistory.Message(AuthorRole.User, userInput),
                        inferenceParams))
                {
                    //Console.Write(text);
                    Console.Write(ConvertFromEncodingToAnother(text, Encoding.UTF8, Encoding.GetEncoding("gb2312")));
                    Console.ForegroundColor = ConsoleColor.White;
                    Console.Write(text);
                }
                Console.ForegroundColor = ConsoleColor.Green;
                prompt = Console.ReadLine();
                Console.ForegroundColor = ConsoleColor.White;
            }

            Console.ForegroundColor = ConsoleColor.Green;
            userInput = Console.ReadLine() ?? "";
            Console.ForegroundColor = ConsoleColor.White;
        }
    }
}
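// Editor's note: a hedged sketch (not in the PR) of the round trip that
// ConvertEncoding performs in the example above. Console input arrives as GB2312
// on a Chinese-locale Windows console, the model expects UTF-8, and model output
// is converted back before printing. Assumes the System.Text.Encoding.CodePages
// provider is available, as the example itself requires.
Encoding.RegisterProvider(CodePagesEncodingProvider.Instance);
var gb2312 = Encoding.GetEncoding("gb2312");
string forModel = ConvertEncoding(Console.ReadLine() ?? "", gb2312, Encoding.UTF8); // to the model
string forConsole = ConvertEncoding(forModel, Encoding.UTF8, gb2312);               // back to the console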
@@ -1,44 +1,61 @@
using LLama.Common;

namespace LLama.Examples.Examples
namespace LLama.Examples.Examples;

public class ChatSessionStripRoleName
{
    public class ChatSessionStripRoleName
    public static async Task Run()
    {
        public static async Task Run()
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        var parameters = new ModelParams(modelPath)
        {
            Console.Write("Please input your model path: ");
            var modelPath = Console.ReadLine();
            var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
            ContextSize = 1024,
            Seed = 1337,
            GpuLayerCount = 5
        };
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

            var parameters = new ModelParams(modelPath)
            {
                ContextSize = 1024,
                Seed = 1337,
                GpuLayerCount = 5
            };
            using var model = LLamaWeights.LoadFromFile(parameters);
            using var context = model.CreateContext(parameters);
            var executor = new InteractiveExecutor(context);
            var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8));

            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("The chat session has started. The role names won't be printed.");
            Console.ForegroundColor = ConsoleColor.White;
        var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
        ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

            // show the prompt
            Console.Write(prompt);
            while (true)
            {
                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                {
                    Console.Write(text);
                }
        ChatSession session = new(executor, chatHistory);
        session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
            new string[] { "User:", "Assistant:" },
            redundancyLength: 8));

        InferenceParams inferenceParams = new InferenceParams()
        {
            Temperature = 0.9f,
            AntiPrompts = new List<string> { "User:" }
        };

                Console.ForegroundColor = ConsoleColor.Green;
                prompt = Console.ReadLine();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("The chat session has started.");

        // show the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        string userInput = Console.ReadLine() ?? "";

        while (userInput != "exit")
        {
            await foreach (
                var text
                in session.ChatAsync(
                    new ChatHistory.Message(AuthorRole.User, userInput),
                    inferenceParams))
            {
                Console.ForegroundColor = ConsoleColor.White;
                Console.Write(text);
            }

            Console.ForegroundColor = ConsoleColor.Green;
            userInput = Console.ReadLine() ?? "";
            Console.ForegroundColor = ConsoleColor.White;
        }
    }
}
@@ -0,0 +1,98 @@
using LLama.Common;

namespace LLama.Examples.Examples;

public class ChatSessionWithHistory
{
    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        var parameters = new ModelParams(modelPath)
        {
            ContextSize = 1024,
            Seed = 1337,
            GpuLayerCount = 5
        };
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        ChatSession session;
        if (Directory.Exists("Assets/chat-with-bob"))
        {
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("Loading session from disk.");
            Console.ForegroundColor = ConsoleColor.White;

            session = new ChatSession(executor);
            session.LoadSession("Assets/chat-with-bob");
        }
        else
        {
            var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
            ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

            session = new ChatSession(executor, chatHistory);
        }

        session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
            new string[] { "User:", "Assistant:" },
            redundancyLength: 8));

        InferenceParams inferenceParams = new InferenceParams()
        {
            Temperature = 0.9f,
            AntiPrompts = new List<string> { "User:" }
        };

        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("The chat session has started.");

        // show the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        string userInput = Console.ReadLine() ?? "";

        while (userInput != "exit")
        {
            if (userInput == "save")
            {
                session.SaveSession("Assets/chat-with-bob");
                Console.ForegroundColor = ConsoleColor.Yellow;
                Console.WriteLine("Session saved.");
            }
            else if (userInput == "regenerate")
            {
                Console.ForegroundColor = ConsoleColor.Yellow;
                Console.WriteLine("Regenerating last response ...");

                await foreach (
                    var text
                    in session.RegenerateAssistantMessageAsync(
                        inferenceParams))
                {
                    Console.ForegroundColor = ConsoleColor.White;
                    Console.Write(text);
                }
            }
            else
            {
                await foreach (
                    var text
                    in session.ChatAsync(
                        new ChatHistory.Message(AuthorRole.User, userInput),
                        inferenceParams))
                {
                    Console.ForegroundColor = ConsoleColor.White;
                    Console.Write(text);
                }
            }

            Console.ForegroundColor = ConsoleColor.Green;
            userInput = Console.ReadLine() ?? "";
            Console.ForegroundColor = ConsoleColor.White;
        }
    }
}
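// Editor's note: a small sketch (not in the PR) of what the "save" command above
// produces. SaveSession writes three files into the session directory, named by the
// constants declared on ChatSession later in this diff.
foreach (var file in Directory.EnumerateFiles("Assets/chat-with-bob"))
    Console.WriteLine(Path.GetFileName(file)); // ModelState.st, ExecutorState.json, ChatHistory.json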
@@ -1,44 +1,58 @@
using LLama.Common;

namespace LLama.Examples.Examples
namespace LLama.Examples.Examples;

public class ChatSessionWithRoleName
{
    public class ChatSessionWithRoleName
    public static async Task Run()
    {
        public static async Task Run()
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        var parameters = new ModelParams(modelPath)
        {
            Console.Write("Please input your model path: ");
            var modelPath = Console.ReadLine();
            var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
            ContextSize = 1024,
            Seed = 1337,
            GpuLayerCount = 5
        };
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

            var parameters = new ModelParams(modelPath)
            {
                ContextSize = 1024,
                Seed = 1337,
                GpuLayerCount = 5
            };
            using var model = LLamaWeights.LoadFromFile(parameters);
            using var context = model.CreateContext(parameters);
            var executor = new InteractiveExecutor(context);
            var session = new ChatSession(executor);

            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result.");
            Console.ForegroundColor = ConsoleColor.White;
        var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
        ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

            // show the prompt
            Console.Write(prompt);
            while (true)
            {
                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                {
                    Console.Write(text);
                }
        ChatSession session = new(executor, chatHistory);

        InferenceParams inferenceParams = new InferenceParams()
        {
            Temperature = 0.9f,
            AntiPrompts = new List<string> { "User:" }
        };

                Console.ForegroundColor = ConsoleColor.Green;
                prompt = Console.ReadLine();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("The chat session has started.");

        // show the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        string userInput = Console.ReadLine() ?? "";

        while (userInput != "exit")
        {
            await foreach (
                var text
                in session.ChatAsync(
                    new ChatHistory.Message(AuthorRole.User, userInput),
                    inferenceParams))
            {
                Console.ForegroundColor = ConsoleColor.White;
                Console.Write(text);
            }

            Console.ForegroundColor = ConsoleColor.Green;
            userInput = Console.ReadLine() ?? "";
            Console.ForegroundColor = ConsoleColor.White;
        }
    }
}
@@ -1,4 +1,5 @@
using LLama.Common;
using DocumentFormat.OpenXml.Bibliography;
using LLama.Common;

namespace LLama.Examples.Examples
{
@@ -30,7 +31,15 @@ namespace LLama.Examples.Examples
            Console.Write(prompt);
            while (true)
            {
                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                await foreach (
                    var text
                    in session.ChatAsync(
                        new ChatHistory.Message(AuthorRole.User, prompt),
                        new InferenceParams()
                        {
                            Temperature = 0.6f,
                            AntiPrompts = new List<string> { "User:" }
                        }))
                {
                    Console.Write(text);
                }
@@ -6,8 +6,10 @@ public class Runner
{
    private static readonly Dictionary<string, Func<Task>> Examples = new()
    {
        { "Run a chat session with history.", ChatSessionWithHistory.Run },
        { "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run },
        { "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run },
        { "Run a chat session in Chinese GB2312 encoding", ChatChineseGB2312.Run },
        { "Interactive mode chat by using executor.", InteractiveModeExecute.Run },
        { "Instruct mode chat by using executor.", InstructModeExecute.Run },
        { "Stateless mode chat by using executor.", StatelessModeExecute.Run },
@@ -23,7 +25,6 @@ public class Runner
        { "Coding Assistant.", CodingAssistant.Run },
        { "Batch Decoding.", BatchedDecoding.Run },
        { "SK Kernel Memory.", KernelMemory.Run },
        { "Chinese gb2312 chat", ChatChineseGB2312.Run },
        { "Exit", async () => Environment.Exit(0) }
    };
@@ -2,7 +2,7 @@
  <Import Project="..\LLama\LLamaSharp.Runtime.targets" />
  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net6.0</TargetFramework>
    <TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <Platforms>AnyCPU;x64</Platforms>
@@ -27,6 +27,12 @@
  </ItemGroup>
  <ItemGroup>
    <None Update="Assets\chat-with-bob.json">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\chat-with-kunkun-chinese.json">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\chat-with-bob.txt">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <TargetFrameworks>net6.0;net7.0</TargetFrameworks>
    <TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <Version>0.8.0</Version>
@@ -1,5 +1,4 @@
using System.Text;
using LLama.Exceptions;
using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
  <Import Project="..\LLama\LLamaSharp.Runtime.targets" />
  <PropertyGroup>
    <TargetFramework>net6.0</TargetFramework>
    <TargetFramework>net8.0</TargetFramework>
    <RootNamespace>LLama.Unittest</RootNamespace>
    <ImplicitUsings>enable</ImplicitUsings>
    <Platforms>AnyCPU;x64</Platforms>
@@ -15,8 +15,8 @@
  <ItemGroup>
    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
    <PackageReference Include="System.Linq.Async" Version="6.0.1" />
    <PackageReference Include="xunit" Version="2.6.2" />
    <PackageReference Include="xunit.runner.visualstudio" Version="2.5.4">
    <PackageReference Include="xunit" Version="2.6.3" />
    <PackageReference Include="xunit.runner.visualstudio" Version="2.5.5">
      <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
      <PrivateAssets>all</PrivateAssets>
    </PackageReference>
@@ -1,4 +1,5 @@
using LLama.Common;
using System.Text.Json;

namespace LLama.Unittest
{
@@ -16,14 +17,19 @@ namespace LLama.Unittest
            TensorSplits = { [0] = 3 }
        };

        var json = System.Text.Json.JsonSerializer.Serialize(expected);
        var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json)!;
        var json = JsonSerializer.Serialize(expected);
        var actual = JsonSerializer.Deserialize<ModelParams>(json)!;

        // Cannot compare splits with default equality, check they are sequence equal and then set to null
        Assert.Equal((IEnumerable<float>)expected.TensorSplits, actual.TensorSplits);
        actual.TensorSplits = null!;
        expected.TensorSplits = null!;

        // Check encoding is the same
        var b1 = expected.Encoding.GetBytes("Hello");
        var b2 = actual.Encoding.GetBytes("Hello");
        Assert.True(b1.SequenceEqual(b2));

        Assert.Equal(expected, actual);
    }
@@ -1,5 +1,6 @@
using System.Diagnostics;
using LLama.Common;
using LLama.Sampling;
using Xunit.Abstractions;

namespace LLama.Unittest
@@ -30,10 +31,13 @@ namespace LLama.Unittest
    [Fact]
    public async Task Stateless()
    {
        // Create a custom pipeline that mimics the default pipeline
        var pipeline = new DefaultSamplingPipeline();

        var executor = new StatelessExecutor(_weights, _params);

        const string question = "Question. what is a cat?\nAnswer: ";
        var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
        var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

        var timer = new Stopwatch();
        timer.Start();
@@ -1,6 +1,9 @@
using LLama.Common;
#nullable enable

using LLama.Common;
using LLama.Abstractions;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Web.Common
{
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
        /// <summary>
        /// A grammar to constrain possible tokens
        /// </summary>
        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
        public SafeLLamaGrammarHandle? Grammar { get; set; }

        /// <inheritdoc />
        public ISamplingPipeline? SamplingPipeline { get; set; }
    }
}
@@ -11,8 +11,7 @@ public class StatefulChatService : IDisposable
    private readonly LLamaContext _context;
    private bool _continue = false;

    private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n"
        + "User: ";
    private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.";

    public StatefulChatService(IConfiguration configuration)
    {
@@ -25,7 +24,9 @@ public class StatefulChatService : IDisposable
        using var weights = LLamaWeights.LoadFromFile(@params);

        _context = new LLamaContext(weights, @params);

        _session = new ChatSession(new InteractiveExecutor(_context));
        _session.History.AddMessage(Common.AuthorRole.System, SystemPrompt);
    }

    public void Dispose()
@@ -35,10 +36,8 @@ public class StatefulChatService : IDisposable
    public async Task<string> Send(SendMessageInput input)
    {
        var userInput = input.Text;
        if (!_continue)
        {
            userInput = SystemPrompt + userInput;
            Console.Write(SystemPrompt);
            _continue = true;
        }
@@ -47,11 +46,14 @@ public class StatefulChatService : IDisposable
        Console.Write(input.Text);
        Console.ForegroundColor = ConsoleColor.White;

        var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
        {
            RepeatPenalty = 1.0f,
            AntiPrompts = new string[] { "User:" },
        });
        var outputs = _session.ChatAsync(
            new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
            new Common.InferenceParams()
            {
                RepeatPenalty = 1.0f,
                AntiPrompts = new string[] { "User:" },
            });

        var result = "";
        await foreach (var output in outputs)
        {
@@ -64,10 +66,8 @@ public class StatefulChatService : IDisposable
    public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
    {
        var userInput = input.Text;
        if (!_continue)
        {
            userInput = SystemPrompt + userInput;
            Console.Write(SystemPrompt);
            _continue = true;
        }
@@ -76,11 +76,14 @@ public class StatefulChatService : IDisposable
        Console.Write(input.Text);
        Console.ForegroundColor = ConsoleColor.White;

        var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
        {
            RepeatPenalty = 1.0f,
            AntiPrompts = new string[] { "User:" },
        });
        var outputs = _session.ChatAsync(
            new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
            new Common.InferenceParams()
            {
                RepeatPenalty = 1.0f,
                AntiPrompts = new string[] { "User:" },
            });
        await foreach (var output in outputs)
        {
            Console.Write(output);
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Abstractions
{
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
        /// Grammar to constrain possible tokens
        /// </summary>
        SafeLLamaGrammarHandle? Grammar { get; set; }

        /// <summary>
        /// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
        /// </summary>
        ISamplingPipeline? SamplingPipeline { get; set; }
    }
}
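// Editor's note: a minimal sketch (not in the PR) of the new property in use,
// following the unit test earlier in this diff. DefaultSamplingPipeline mimics the
// built-in sampling; per the doc comment above, once SamplingPipeline is set the
// other sampling parameters (Temperature, TopK, ...) on InferenceParams are ignored.
// Assumes `using LLama.Common;` and `using LLama.Sampling;`.
var inferenceParams = new InferenceParams
{
    MaxTokens = 64,
    SamplingPipeline = new DefaultSamplingPipeline(),
};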
@@ -3,6 +3,9 @@ using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Common;
using LLama.Native;

namespace LLama.Abstractions
@@ -105,6 +108,7 @@ namespace LLama.Abstractions
    /// <summary>
    /// A fixed size array to set the tensor splits across multiple GPUs
    /// </summary>
    [JsonConverter(typeof(TensorSplitsCollectionConverter))]
    public sealed class TensorSplitsCollection
        : IEnumerable<float>
    {
@@ -174,4 +178,24 @@ namespace LLama.Abstractions
        }
        #endregion
    }

    /// <summary>
    /// A JSON converter for <see cref="TensorSplitsCollection"/>
    /// </summary>
    public class TensorSplitsCollectionConverter
        : JsonConverter<TensorSplitsCollection>
    {
        /// <inheritdoc/>
        public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
            return new TensorSplitsCollection(arr);
        }

        /// <inheritdoc/>
        public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
        {
            JsonSerializer.Serialize(writer, value.Splits, options);
        }
    }
}
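// Editor's note: a hedged sketch (not in the PR) of the JSON shape the converter
// above yields: a TensorSplitsCollection round-trips as a plain float array.
// Assumes `using System.Text.Json;` and `using LLama.Abstractions;`.
var splits = new TensorSplitsCollection(new[] { 3f, 1f });
string splitsJson = JsonSerializer.Serialize(splits);                        // "[3,1]"
var restoredSplits = JsonSerializer.Deserialize<TensorSplitsCollection>(splitsJson);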
@@ -1,246 +1,496 @@
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.InteractiveExecutor;

namespace LLama
namespace LLama;

/// <summary>
/// The main chat session class.
/// </summary>
public class ChatSession
{
    private const string _modelStateFilename = "ModelState.st";
    private const string _executorStateFilename = "ExecutorState.json";
    private const string _historyFilename = "ChatHistory.json";

    /// <summary>
    /// The main chat session class.
    /// </summary>
    public class ChatSession
    {
        private readonly ILLamaExecutor _executor;
        private readonly ChatHistory _history;

        private const string _executorStateFilename = "ExecutorState.json";
        private const string _modelStateFilename = "ModelState.st";

        /// <summary>
        /// The executor for this session.
        /// </summary>
        public ILLamaExecutor Executor => _executor;
        /// <summary>
        /// The chat history for this session.
        /// </summary>
        public ChatHistory History => _history;
        /// <summary>
        /// The history transform used in this session.
        /// </summary>
        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
        /// <summary>
        /// The input transform pipeline used in this session.
        /// </summary>
        public List<ITextTransform> InputTransformPipeline { get; set; } = new();
        /// <summary>
        /// The output transform used in this session.
        /// </summary>
        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

        /// <summary>
        ///
        /// </summary>
        /// <param name="executor">The executor for this session</param>
        public ChatSession(ILLamaExecutor executor)
        {
            _executor = executor;
            _history = new ChatHistory();
        }

        /// <summary>
        /// Use a custom history transform.
        /// </summary>
        /// <param name="transform"></param>
        /// <returns></returns>
        public ChatSession WithHistoryTransform(IHistoryTransform transform)
        {
            HistoryTransform = transform;
            return this;
        }

        /// <summary>
        /// Add a text transform to the input transform pipeline.
        /// </summary>
        /// <param name="transform"></param>
        /// <returns></returns>
        public ChatSession AddInputTransform(ITextTransform transform)
        {
            InputTransformPipeline.Add(transform);
            return this;
        }

        /// <summary>
        /// Use a custom output transform.
        /// </summary>
        /// <param name="transform"></param>
        /// <returns></returns>
        public ChatSession WithOutputTransform(ITextStreamTransform transform)
        {
            OutputTransform = transform;
            return this;
        }
        /// <summary>
        ///
        /// </summary>
        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
        public virtual void SaveSession(string path)
        {
            if (!Directory.Exists(path))
            {
                Directory.CreateDirectory(path);
            }
            _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
    /// The executor for this session.
    /// </summary>
    public ILLamaExecutor Executor { get; private set; }
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
            }

    /// <summary>
    /// The chat history for this session.
    /// </summary>
    public ChatHistory History { get; private set; } = new();

    /// <summary>
    /// The history transform used in this session.
    /// </summary>
    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

    /// <summary>
    /// The input transform pipeline used in this session.
    /// </summary>
    public List<ITextTransform> InputTransformPipeline { get; set; } = new();

    /// <summary>
    /// The output transform used in this session.
    /// </summary>
    public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

    /// <summary>
    /// Create a new chat session.
    /// </summary>
    /// <param name="executor">The executor for this session</param>
    public ChatSession(ILLamaExecutor executor)
    {
        // Check if executor has StatefulExecutorBase as base class
        if (executor is not StatefulExecutorBase)
        {
            throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
        }

        /// <summary>
        ///
        /// </summary>
        /// <param name="path">The directory name to load the session.</param>
        public virtual void LoadSession(string path)
        Executor = executor;
    }

    /// <summary>
    /// Create a new chat session with a custom history.
    /// </summary>
    /// <param name="executor"></param>
    /// <param name="history"></param>
    public ChatSession(ILLamaExecutor executor, ChatHistory history)
        : this(executor)
    {
        History = history;
    }

    /// <summary>
    /// Use a custom history transform.
    /// </summary>
    /// <param name="transform"></param>
    /// <returns></returns>
    public ChatSession WithHistoryTransform(IHistoryTransform transform)
    {
        HistoryTransform = transform;
        return this;
    }

    /// <summary>
    /// Add a text transform to the input transform pipeline.
    /// </summary>
    /// <param name="transform"></param>
    /// <returns></returns>
    public ChatSession AddInputTransform(ITextTransform transform)
    {
        InputTransformPipeline.Add(transform);
        return this;
    }

    /// <summary>
    /// Use a custom output transform.
    /// </summary>
    /// <param name="transform"></param>
    /// <returns></returns>
    public ChatSession WithOutputTransform(ITextStreamTransform transform)
    {
        OutputTransform = transform;
        return this;
    }

    /// <summary>
    /// Save a session from a directory.
    /// </summary>
    /// <param name="path"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentException"></exception>
    public void SaveSession(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            if (!Directory.Exists(path))
            {
                throw new FileNotFoundException($"Directory {path} does not exist.");
            }
            _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }
    }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
            }

        if (Directory.Exists(path))
        {
            Directory.Delete(path, recursive: true);
        }

        /// <summary>
        /// Generates a response for a given user prompt and manages history state for the user.
        /// This will always pass the whole history to the model. Don't pass a whole history
        /// to this method as the user prompt will be appended to the history of the current session.
        /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns>Returns generated text of the assistant message.</returns>
        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        Directory.CreateDirectory(path);

        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
        Executor.Context.SaveState(modelStateFilePath);

        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
        ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);

        string historyFilepath = Path.Combine(path, _historyFilename);
        File.WriteAllText(historyFilepath, History.ToJson());
    }
    /// <summary>
    /// Load a session from a directory.
    /// </summary>
    /// <param name="path"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentException"></exception>
    public void LoadSession(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            foreach (var inputTransform in InputTransformPipeline)
                prompt = inputTransform.Transform(prompt);
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }

        if (!Directory.Exists(path))
        {
            throw new ArgumentException("Directory does not exist", nameof(path));
        }

        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
        Executor.Context.LoadState(modelStateFilePath);

            // TODO: need to be refactored.
            if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun)
        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
        ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);

        string historyFilepath = Path.Combine(path, _historyFilename);
        string historyJson = File.ReadAllText(historyFilepath);
        History = ChatHistory.FromJson(historyJson)
            ?? throw new ArgumentException("History file is invalid", nameof(path));
    }

    /// <summary>
    /// Add a message to the chat history.
    /// </summary>
    /// <param name="message"></param>
    /// <returns></returns>
    public ChatSession AddMessage(ChatHistory.Message message)
    {
        // If current message is a system message, only allow the history to be empty
        if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
        {
            throw new ArgumentException("Cannot add a system message after another message", nameof(message));
        }

        // If current message is a user message, only allow the history to be empty,
        // or the previous message to be a system message or assistant message.
        if (message.AuthorRole == AuthorRole.User)
        {
            ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
            if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
            {
                History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt));
                var converted_prompt = HistoryTransform.HistoryToText(History);
                // Avoid missing anti-prompt.
                if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n"))
                {
                    prompt = converted_prompt.Trim();
                }
                else
                {
                    prompt = converted_prompt;
                }
                throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
            }
            else
        }

        // If the current message is an assistant message,
        // the previous message must be a user message.
        if (message.AuthorRole == AuthorRole.Assistant)
        {
            ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
            if (lastMessage is null
                || lastMessage.AuthorRole != AuthorRole.User)
            {
                History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
                throw new ArgumentException("Assistant message must be preceded by a user message", nameof(message));
            }
        }

        History.AddMessage(message.AuthorRole, message.Content);
        return this;
    }
    /// <summary>
    /// Add a system message to the chat history.
    /// </summary>
    /// <param name="content"></param>
    /// <returns></returns>
    public ChatSession AddSystemMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.System, content));

    /// <summary>
    /// Add an assistant message to the chat history.
    /// </summary>
    /// <param name="content"></param>
    /// <returns></returns>
    public ChatSession AddAssistantMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

    /// <summary>
    /// Add a user message to the chat history.
    /// </summary>
    /// <param name="content"></param>
    /// <returns></returns>
    public ChatSession AddUserMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.User, content));

            StringBuilder sb = new();

    /// <summary>
    /// Remove the last message from the chat history.
    /// </summary>
    /// <returns></returns>
    public ChatSession RemoveLastMessage()
    {
        History.Messages.RemoveAt(History.Messages.Count - 1);
        return this;
    }

    /// <summary>
    /// Replace a user message with a new message and remove all messages after the new message.
    /// This is useful when the user wants to edit a message and regenerate the response.
    /// </summary>
    /// <param name="oldMessage"></param>
    /// <param name="newMessage"></param>
    /// <returns></returns>
    public ChatSession ReplaceUserMessage(
        ChatHistory.Message oldMessage,
        ChatHistory.Message newMessage)
    {
        if (oldMessage.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
        }

            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
        if (newMessage.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("New message must be a user message", nameof(newMessage));
        }

        int index = History.Messages.IndexOf(oldMessage);
        if (index == -1)
        {
            throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
        }

        History.Messages[index] = newMessage;

        // Remove all messages after the new message
        History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);

        return this;
    }

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <param name="message"></param>
    /// <param name="inferenceParams"></param>
    /// <param name="applyInputTransformPipeline"></param>
    /// <param name="cancellationToken"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentException"></exception>
    public async IAsyncEnumerable<string> ChatAsync(
        ChatHistory.Message message,
        bool applyInputTransformPipeline,
        IInferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // The message must be a user message
        if (message.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("Message must be a user message", nameof(message));
        }

        // Apply input transform pipeline
        if (applyInputTransformPipeline)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                yield return result;
                sb.Append(result);
                message.Content = inputTransform.Transform(message.Content);
            }
        }

        // Add the user's message to the history
        AddUserMessage(message.Content);

        // Prepare prompt variable
        string prompt;

        // Check if the session history was restored from a previous session
        // or added as part of new chat session history.
        InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

        // If "IsPromptRun" is true, the session was newly started.
        if (state.IsPromptRun)
        {
            // If the session history was added as part of new chat session history,
            // convert the complete history including the system message and manually added history
            // to a prompt that adheres to the prompt template specified in the HistoryTransform class implementation.
            prompt = HistoryTransform.HistoryToText(History);
        }
        else
        {
            // If the session was restored from a previous session,
            // convert only the current message to the prompt with the prompt template
            // specified in the HistoryTransform class implementation that is provided.
            ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
            prompt = HistoryTransform.HistoryToText(singleMessageHistory);
        }

        string assistantMessage = string.Empty;

        await foreach (
            string textToken
            in ChatAsyncInternal(
                prompt,
                inferenceParams,
                cancellationToken))
        {
            assistantMessage += textToken;
            yield return textToken;
        }

        // Add the assistant message to the history
        AddAssistantMessage(assistantMessage);
    }

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <param name="message"></param>
    /// <param name="inferenceParams"></param>
    /// <param name="cancellationToken"></param>
    /// <returns></returns>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory.Message message,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        return ChatAsync(
            message,
            applyInputTransformPipeline: true,
            inferenceParams,
            cancellationToken);
    }
            string assistantMessage = sb.ToString();

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <param name="history"></param>
    /// <param name="applyInputTransformPipeline"></param>
    /// <param name="inferenceParams"></param>
    /// <param name="cancellationToken"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentException"></exception>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory history,
        bool applyInputTransformPipeline,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
            ?? throw new ArgumentException("History must contain at least one message", nameof(history));

            // Remove end tokens from the assistant message
            // if defined in inferenceParams.AntiPrompts.
            // We only want the response that was generated and not tokens
            // that are delimiting the beginning or end of the response.
            if (inferenceParams?.AntiPrompts != null)
        foreach (
            ChatHistory.Message message
            in history.Messages.Take(history.Messages.Count - 1))
        {
            // Apply input transform pipeline
            if (applyInputTransformPipeline
                && message.AuthorRole == AuthorRole.User)
            {
                foreach (var stopToken in inferenceParams.AntiPrompts)
                foreach (
                    var inputTransform
                    in InputTransformPipeline)
                {
                    assistantMessage = assistantMessage.Replace(stopToken, "");
                    message.Content = inputTransform.Transform(message.Content);
                }
            }

            History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
            AddMessage(message);
        }

        /// <summary>
        /// Generates a response for a given chat history. This method does not manage history state for the user.
        /// If you want to e.g. truncate the history of a session to fit into the model's context window,
        /// use this method and pass the truncated history to it. If you don't need this control, use the other
        /// overload of this method that accepts a user prompt instead.
        /// </summary>
        /// <param name="history"></param>
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns>Returns generated text of the assistant message.</returns>
        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        return ChatAsync(
            lastMessage,
            applyInputTransformPipeline,
            inferenceParams,
            cancellationToken);
    }

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <param name="history"></param>
    /// <param name="inferenceParams"></param>
    /// <param name="cancellationToken"></param>
    /// <returns></returns>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory history,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        return ChatAsync(
            history,
            applyInputTransformPipeline: true,
            inferenceParams,
            cancellationToken);
    }
    /// <summary>
    /// Regenerate the last assistant message.
    /// </summary>
    /// <param name="inferenceParams"></param>
    /// <param name="cancellationToken"></param>
    /// <returns></returns>
    /// <exception cref="InvalidOperationException"></exception>
    public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
        InferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Make sure the last message is an assistant message (response from the LLM).
        ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();
        if (lastAssistantMessage is null
            || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
        {
            if (history.Messages.Count == 0)
            {
                throw new ArgumentException("History must contain at least one message.");
            }
            throw new InvalidOperationException("Last message must be an assistant message");
        }

            string prompt;
            if (_executor is InteractiveExecutor executor)
            {
                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
        // Remove the last assistant message from the history.
        RemoveLastMessage();

                prompt = state.IsPromptRun
                    ? HistoryTransform.HistoryToText(History)
                    : history.Messages.Last().Content;
            }
            else
            {
                prompt = history.Messages.Last().Content;
            }

        // Get the last user message.
        ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();

            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
            }
        if (lastUserMessage is null
            || lastUserMessage.AuthorRole != AuthorRole.User)
        {
            throw new InvalidOperationException("Last message must be a user message");
        }

        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        // Remove the last user message from the history.
        RemoveLastMessage();

        // Regenerate the assistant message.
        await foreach (
            string textToken
            in ChatAsync(
                lastUserMessage,
                applyInputTransformPipeline: false,
                inferenceParams,
                cancellationToken))
        {
            var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
            await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
            {
                yield return item;
            }
            yield return textToken;
        }
    }

    private async IAsyncEnumerable<string> ChatAsyncInternal(
        string prompt,
        IInferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);

        await foreach (
            string textToken
            in OutputTransform
                .TransformAsync(results)
                .WithCancellation(cancellationToken))
        {
            yield return textToken;
        }
    }
}
}
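// Editor's note: a hedged sketch (not in the PR) of editing the last exchange with
// the APIs above. ChatAsync rejects a user message that directly follows another
// user message, so one possible flow is to drop the last assistant/user pair and
// then send the revised message. Assumes a ChatSession `session` and
// InferenceParams `inferenceParams` set up as in the examples earlier, and that
// the history currently ends with a user/assistant exchange.
session.RemoveLastMessage(); // the last assistant message
session.RemoveLastMessage(); // the last user message
await foreach (var token in session.ChatAsync(
    new ChatHistory.Message(AuthorRole.User, "the edited question"), inferenceParams))
{
    Console.Write(token);
}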
| @@ -1,4 +1,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text.Json; | |||
| using System.Text.Json.Serialization; | |||
| namespace LLama.Common | |||
| { | |||
| @@ -43,11 +46,14 @@ namespace LLama.Common | |||
| /// <summary> | |||
| /// Role of the message author, e.g. user/assistant/system | |||
| /// </summary> | |||
| [JsonConverter(typeof(JsonStringEnumConverter))] | |||
| [JsonPropertyName("author_role")] | |||
| public AuthorRole AuthorRole { get; set; } | |||
| /// <summary> | |||
| /// Message content | |||
| /// </summary> | |||
| [JsonPropertyName("content")] | |||
| public string Content { get; set; } | |||
| /// <summary> | |||
| @@ -65,15 +71,14 @@ namespace LLama.Common | |||
| /// <summary> | |||
| /// List of messages in the chat | |||
| /// </summary> | |||
| public List<Message> Messages { get; } | |||
| [JsonPropertyName("messages")] | |||
| public List<Message> Messages { get; set; } = new(); | |||
| /// <summary> | |||
| /// Create a new instance of the chat content class | |||
| /// </summary> | |||
| public ChatHistory() | |||
| { | |||
| this.Messages = new List<Message>(); | |||
| } | |||
| [JsonConstructor] | |||
| public ChatHistory() { } | |||
| /// <summary> | |||
| /// Add a message to the chat history | |||
| @@ -84,6 +89,29 @@ namespace LLama.Common | |||
| { | |||
| this.Messages.Add(new Message(authorRole, content)); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Serialize the chat history to JSON | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public string ToJson() | |||
| { | |||
| return JsonSerializer.Serialize( | |||
| this, | |||
| new JsonSerializerOptions() | |||
| { | |||
| WriteIndented = true | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// Deserialize a chat history from JSON | |||
| /// </summary> | |||
| /// <param name="json"></param> | |||
| /// <returns></returns> | |||
| public static ChatHistory? FromJson(string json) | |||
| { | |||
| return JsonSerializer.Deserialize<ChatHistory>(json); | |||
| } | |||
| } | |||
| } | |||
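With the `JsonPropertyName` and `JsonStringEnumConverter` attributes above, `ToJson`/`FromJson` round-trip the same JSON shape as the sample transcripts included in this PR. A quick sketch:

```csharp
var history = new ChatHistory();
history.AddMessage(AuthorRole.System, "Transcript of a dialog between a User and an Assistant.");
history.AddMessage(AuthorRole.User, "Hello, Bob.");

// Enum values serialize as strings ("User", not 1); keys come out as
// "messages" / "author_role" / "content".
string json = history.ToJson();

// Deserialization goes through the [JsonConstructor] and the Messages setter.
ChatHistory? restored = ChatHistory.FromJson(json);
```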
| @@ -2,6 +2,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using LLama.Native; | |||
| using LLama.Sampling; | |||
| namespace LLama.Common | |||
| { | |||
| @@ -76,6 +77,9 @@ namespace LLama.Common | |||
| /// <inheritdoc /> | |||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| /// <inheritdoc /> | |||
| public ISamplingPipeline? SamplingPipeline { get; set; } | |||
| } | |||
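This property is the opt-in switch for the new sampling system: the executor hunks below check `SamplingPipeline` first and only fall back to the classic parameter-driven path when it is null. A minimal sketch, using the `DefaultSamplingPipeline` added later in this diff:

```csharp
var inferenceParams = new InferenceParams
{
    // When set, Temperature/TopK/Mirostat etc. on this object are ignored
    // and the pipeline performs all logit processing and token selection.
    SamplingPipeline = new DefaultSamplingPipeline
    {
        Temperature = 0.6f,
        RepeatPenalty = 1.1f,
    },
};
```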
| /// <summary> | |||
| @@ -59,7 +59,6 @@ namespace LLama.Common | |||
| public bool EmbeddingMode { get; set; } | |||
| /// <inheritdoc /> | |||
| [JsonConverter(typeof(TensorSplitsCollectionConverter))] | |||
| public TensorSplitsCollection TensorSplits { get; set; } = new(); | |||
| /// <inheritdoc /> | |||
| @@ -92,9 +91,20 @@ namespace LLama.Common | |||
| /// <inheritdoc /> | |||
| public bool VocabOnly { get; set; } | |||
| /// <summary> | |||
| /// `Encoding` cannot be directly JSON serialized; instead store the encoding's web name as a string, which can be. | |||
| /// </summary> | |||
| [JsonPropertyName("Encoding")] | |||
| [JsonInclude] | |||
| private string EncodingName { get; set; } = Encoding.UTF8.WebName; | |||
| /// <inheritdoc /> | |||
| [JsonConverter(typeof(EncodingConverter))] | |||
| public Encoding Encoding { get; set; } = Encoding.UTF8; | |||
| [JsonIgnore] | |||
| public Encoding Encoding | |||
| { | |||
| get => Encoding.GetEncoding(EncodingName); | |||
| set => EncodingName = value.WebName; | |||
| } | |||
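The private `EncodingName` property is what actually hits the JSON (under the key "Encoding", via `JsonPropertyName`), so `ModelParams` round-trips with plain `JsonSerializer` calls and the two hand-written converters deleted below are no longer needed. Roughly, assuming the JSON-friendly constructor shown further down (the model path is a placeholder):

```csharp
var before = new ModelParams("model.gguf") { Encoding = Encoding.UTF8 };

// Produces ... "Encoding": "utf-8" ... (the WebName, not the Encoding object graph).
string json = JsonSerializer.Serialize(before);

// Reading it back rebuilds the Encoding via Encoding.GetEncoding("utf-8").
var after = JsonSerializer.Deserialize<ModelParams>(json);
```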
| /// <summary> | |||
| /// | |||
| @@ -112,36 +122,4 @@ namespace LLama.Common | |||
| ModelPath = ""; | |||
| } | |||
| } | |||
| internal class EncodingConverter | |||
| : JsonConverter<Encoding> | |||
| { | |||
| public override Encoding? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) | |||
| { | |||
| var name = reader.GetString(); | |||
| if (name == null) | |||
| return null; | |||
| return Encoding.GetEncoding(name); | |||
| } | |||
| public override void Write(Utf8JsonWriter writer, Encoding value, JsonSerializerOptions options) | |||
| { | |||
| writer.WriteStringValue(value.WebName); | |||
| } | |||
| } | |||
| internal class TensorSplitsCollectionConverter | |||
| : JsonConverter<TensorSplitsCollection> | |||
| { | |||
| public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) | |||
| { | |||
| var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>(); | |||
| return new TensorSplitsCollection(arr); | |||
| } | |||
| public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options) | |||
| { | |||
| JsonSerializer.Serialize(writer, value.Splits, options); | |||
| } | |||
| } | |||
| } | |||
| @@ -10,6 +10,7 @@ using LLama.Common; | |||
| using System.Runtime.InteropServices; | |||
| using LLama.Extensions; | |||
| using LLama.Abstractions; | |||
| using LLama.Sampling; | |||
| using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| @@ -212,6 +213,17 @@ namespace LLama | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Sample a single token from this context, using the given sampling pipeline | |||
| /// </summary> | |||
| /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param> | |||
| /// <param name="lastTokens">The tokens recently returned from the model</param> | |||
| /// <returns>The selected token</returns> | |||
| public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens) | |||
| { | |||
| return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); | |||
| } | |||
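This overload is a thin convenience over `ISamplingPipeline.Sample`, pulling the logits from the native handle so callers holding a `LLamaContext` don't have to fetch them. Direct use would look like this sketch (`ctx` and `lastTokens` assumed from surrounding code; most callers will set `InferenceParams.SamplingPipeline` instead):

```csharp
using var pipeline = new DefaultSamplingPipeline();
var next = ctx.Sample(pipeline, lastTokens);
```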
| /// <summary> | |||
| /// Perform the sampling. Please don't use it unless you fully know what it does. | |||
| /// </summary> | |||
| @@ -210,16 +210,24 @@ namespace LLama | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostatMu; | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| llama_token id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); | |||
| } | |||
| else | |||
| { | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostatMu; | |||
| id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| } | |||
| _last_n_tokens.Enqueue(id); | |||
| @@ -189,16 +189,24 @@ namespace LLama | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostatMu; | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| llama_token id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); | |||
| } | |||
| else | |||
| { | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostatMu; | |||
| id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| } | |||
| _last_n_tokens.Enqueue(id); | |||
| @@ -28,6 +28,7 @@ | |||
| <Platforms>AnyCPU;x64;Arm64</Platforms> | |||
| <PackageId>LLamaSharp</PackageId> | |||
| <Configurations>Debug;Release;GPU</Configurations> | |||
| <GenerateAssemblyInfo>false</GenerateAssemblyInfo> | |||
| </PropertyGroup> | |||
| <PropertyGroup> | |||
| @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| using LLama.Native; | |||
| using LLama.Sampling; | |||
| using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| @@ -85,16 +86,24 @@ namespace LLama | |||
| var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | |||
| for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) | |||
| { | |||
| // Penalize the generated tokens by various penalties | |||
| var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| // Sample a single token | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| llama_token id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens); | |||
| } | |||
| else | |||
| { | |||
| // Penalize the generated tokens by various penalties | |||
| var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| // Sample a single token | |||
| id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| } | |||
| // Decode this token into text | |||
| decoder.Add(id); | |||
| @@ -46,14 +46,41 @@ namespace LLama.Native | |||
| return new LLamaTokenDataArray(candidates); | |||
| } | |||
| /// <summary> | |||
| /// Overwrite the logit values for all given tokens | |||
| /// </summary> | |||
| /// <param name="values">tuples of token and logit value to overwrite</param> | |||
| public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values) | |||
| { | |||
| if (values.Length == 0) | |||
| return; | |||
| var dataSpan = data.Span; | |||
| foreach (var (token, value) in values) | |||
| { | |||
| for (var i = 0; i < data.Length; i++) | |||
| { | |||
| if (dataSpan[i].id == token) | |||
| { | |||
| dataSpan[i].logit = value; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| sorted = false; | |||
| } | |||
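Each overwrite does a linear scan of the candidate array for its token id, so this is sized for a handful of protected logits, not bulk edits (the base pipeline below calls it with the saved newline logit). A sketch; `newlineToken` and `savedLogit` are illustrative names:

```csharp
// Restore one saved logit value by token id.
var candidates = LLamaTokenDataArray.Create(logits);
candidates.OverwriteLogits(stackalloc (int, float)[] { (newlineToken, savedLogit) });
```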
| #region sampling | |||
| /// <summary> | |||
| /// Apply grammar rules to candidate tokens | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="grammar"></param> | |||
| public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar) | |||
| public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar) | |||
| { | |||
| if (grammar == null) | |||
| return; | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_grammar(ctx, ref st, grammar); | |||
| @@ -145,15 +172,17 @@ namespace LLama.Native | |||
| /// <param name="penalty_repeat"></param> | |||
| /// <param name="penalty_freq"></param> | |||
| /// <param name="penalty_present"></param> | |||
| public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) | |||
| public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) | |||
| { | |||
| unsafe | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| using (var last_tokens_handle = last_tokens.Pin()) | |||
| { | |||
| NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); | |||
| sorted = st.sorted; | |||
| fixed (int* last_tokens_handle = last_tokens) | |||
| { | |||
| NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| } | |||
| } | |||
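Worth noting on the `RepetitionPenalty` change above: switching the parameter from `Memory<llama_token>` to `ReadOnlySpan<llama_token>` lets callers pass stack buffers or slices without copying, and the `fixed` statement pins the span for the native call directly, replacing the `MemoryHandle` bookkeeping that `Pin()` required.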
| @@ -0,0 +1,128 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Collections.Generic; | |||
| using LLama.Native; | |||
| namespace LLama.Sampling; | |||
| /// <summary> | |||
| /// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`. | |||
| /// </summary> | |||
| public abstract class BaseSamplingPipeline | |||
| : ISamplingPipeline | |||
| { | |||
| private int _savedLogitsCount; | |||
| private (int index, float logit)[]? _savedLogits; | |||
| /// <inheritdoc/> | |||
| public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens) | |||
| { | |||
| var protectedLogits = GetProtectedTokens(ctx); | |||
| _savedLogitsCount = protectedLogits.Count; | |||
| _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount); | |||
| try | |||
| { | |||
| // Save the values of protected logits | |||
| for (var i = 0; i < protectedLogits.Count; i++) | |||
| { | |||
| var index = protectedLogits[i]; | |||
| var value = logits[index]; | |||
| _savedLogits[i] = (index, value); | |||
| } | |||
| // Process raw logits | |||
| ProcessLogits(ctx, logits, lastTokens); | |||
| // Automatically restore saved logit values after processing | |||
| RestoreProtectedTokens(logits); | |||
| // Convert logits into token candidates | |||
| var candidates = LLamaTokenDataArray.Create(logits); | |||
| // Process token data array | |||
| ProcessTokenDataArray(ctx, candidates, lastTokens); | |||
| // Choose the final value | |||
| return ChooseToken(ctx, candidates); | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<(int, float)>.Shared.Return(_savedLogits); | |||
| _savedLogits = null; | |||
| _savedLogitsCount = 0; | |||
| } | |||
| } | |||
| #region protected tokens | |||
| /// <summary> | |||
| /// Get all of the "protected" tokens that cannot be changed by ProcessLogits | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx); | |||
| /// <summary> | |||
| /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits | |||
| /// </summary> | |||
| /// <param name="logits"></param> | |||
| protected void RestoreProtectedTokens(Span<float> logits) | |||
| { | |||
| if (_savedLogits == null) | |||
| return; | |||
| // The rented array may be bigger than necessary; take a span over just the valid part | |||
| var saved = _savedLogits.AsSpan(0, _savedLogitsCount); | |||
| // Restore the values of protected logits | |||
| for (var i = 0; i < saved.Length; i++) | |||
| logits[saved[i].index] = saved[i].logit; | |||
| } | |||
| /// <summary> | |||
| /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits | |||
| /// </summary> | |||
| /// <param name="candidates"></param> | |||
| protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) | |||
| { | |||
| if (_savedLogits == null || _savedLogits.Length == 0) | |||
| return; | |||
| candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount)); | |||
| } | |||
| #endregion | |||
| /// <summary> | |||
| /// Process the raw logit values | |||
| /// </summary> | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens); | |||
| /// <summary> | |||
| /// Process the LLamaTokenDataArray and select a single token | |||
| /// </summary> | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens); | |||
| /// <summary> | |||
| /// Choose the final token from the candidates | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates"></param> | |||
| /// <returns></returns> | |||
| protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates); | |||
| /// <inheritdoc/> | |||
| public virtual void Reset() | |||
| { | |||
| } | |||
| /// <inheritdoc/> | |||
| public virtual void Dispose() | |||
| { | |||
| GC.SuppressFinalize(this); | |||
| } | |||
| } | |||
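Only the four abstract members need bodies to build a custom pipeline on top of this base class. A minimal greedy pipeline as an illustrative sketch (not part of this diff; it assumes the existing `LLamaTokenDataArray.SampleTokenGreedy` helper):

```csharp
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Illustrative only: no logit edits, no protected tokens, greedy selection.
/// </summary>
public sealed class GreedySamplingPipeline
    : BaseSamplingPipeline
{
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        => Array.Empty<int>();

    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Leave the raw logits untouched.
    }

    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        // No candidate processing; the base class takes the final token
        // from ChooseToken, so this return value is not used for selection.
        return 0;
    }

    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
    {
        return candidates.SampleTokenGreedy(ctx);
    }
}
```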
| @@ -0,0 +1,149 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using LLama.Extensions; | |||
| using LLama.Native; | |||
| namespace LLama.Sampling; | |||
| /// <summary> | |||
| /// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling | |||
| /// </summary> | |||
| public sealed class DefaultSamplingPipeline | |||
| : BaseSamplingPipeline | |||
| { | |||
| /// <summary> | |||
| /// Bias values to add to certain logits | |||
| /// </summary> | |||
| public Dictionary<int, float> LogitBias { get; } = new(); | |||
| /// <summary> | |||
| /// Grammar to constrain valid tokens | |||
| /// </summary> | |||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| /// <summary> | |||
| /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 | |||
| /// </summary> | |||
| public float RepeatPenalty { get; set; } = 1.1f; | |||
| /// <summary> | |||
| /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br /> | |||
| /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text | |||
| /// so far, decreasing the model's likelihood to repeat the same line verbatim. | |||
| /// </summary> | |||
| public float AlphaFrequency | |||
| { | |||
| get => _alphaFreq; | |||
| set | |||
| { | |||
| if (value < -2) | |||
| throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); | |||
| if (value > 2) | |||
| throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); | |||
| _alphaFreq = value; | |||
| } | |||
| } | |||
| private float _alphaFreq = 0.1f; | |||
| /// <summary> | |||
| /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br /> | |||
| /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the | |||
| /// text so far, increasing the model's likelihood to talk about new topics. | |||
| /// </summary> | |||
| public float AlphaPresence | |||
| { | |||
| get => _alphaPresence; | |||
| set | |||
| { | |||
| if (value < -2) | |||
| throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2"); | |||
| if (value > 2) | |||
| throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2"); | |||
| _alphaPresence = value; | |||
| } | |||
| } | |||
| private float _alphaPresence = 0.1f; | |||
| /// <summary> | |||
| /// Temperature to apply (higher temperature is more "creative") | |||
| /// </summary> | |||
| public float Temperature { get; set; } = 0.75f; | |||
| /// <summary> | |||
| /// Number of tokens to keep in TopK sampling | |||
| /// </summary> | |||
| public int TopK { get; set; } | |||
| /// <summary> | |||
| /// Z value for tail free sampling | |||
| /// </summary> | |||
| public float TailFreeZ { get; set; } | |||
| /// <summary> | |||
| /// P value for locally typical sampling | |||
| /// </summary> | |||
| public float TypicalP { get; set; } | |||
| /// <summary> | |||
| /// P value for TopP sampling | |||
| /// </summary> | |||
| public float TopP { get; set; } = 1f; | |||
| /// <summary> | |||
| /// P value for MinP sampling | |||
| /// </summary> | |||
| public float MinP { get; set; } | |||
| /// <summary> | |||
| /// Whether the newline token should be penalized. When false (the default), its logit is protected from logit bias and repetition penalties. | |||
| /// </summary> | |||
| public bool PenalizeNewline { get; set; } = false; | |||
| private readonly int[] _newlineToken = new int[1]; | |||
| /// <inheritdoc /> | |||
| protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx) | |||
| { | |||
| if (PenalizeNewline) | |||
| return Array.Empty<int>(); | |||
| _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle); | |||
| return _newlineToken; | |||
| } | |||
| /// <inheritdoc /> | |||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens) | |||
| { | |||
| foreach (var (key, value) in LogitBias) | |||
| logits[key] += value; | |||
| } | |||
| /// <inheritdoc /> | |||
| protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens) | |||
| { | |||
| // Apply penalties to candidates | |||
| candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); | |||
| // Restore protected tokens, so they are not affected by repetition penalties | |||
| RestoreProtectedTokens(candidates); | |||
| // Apply the normal llama.cpp pipeline | |||
| candidates.ApplyGrammar(ctx, Grammar); | |||
| candidates.TopK(ctx, TopK); | |||
| candidates.TailFree(ctx, TailFreeZ); | |||
| candidates.LocallyTypical(ctx, TypicalP); | |||
| candidates.TopP(ctx, TopP); | |||
| candidates.MinP(ctx, MinP); | |||
| candidates.Temperature(ctx, Temperature); | |||
| var id = candidates.SampleToken(ctx); | |||
| Grammar?.AcceptToken(ctx, id); | |||
| return id; | |||
| } | |||
| /// <inheritdoc /> | |||
| protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| { | |||
| return candidates.SampleToken(ctx); | |||
| } | |||
| } | |||
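The knobs mirror the classic `InferenceParams` fields (minus Mirostat), plus newline protection. A configuration sketch; `eosToken` is a placeholder token id:

```csharp
var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.8f,
    TopK = 40,
    TopP = 0.95f,
    MinP = 0.05f,
    RepeatPenalty = 1.1f,
    PenalizeNewline = false, // newline logit survives bias + penalties
};

// Never sample a particular token (placeholder id).
pipeline.LogitBias[eosToken] = float.NegativeInfinity;
```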
| @@ -0,0 +1,61 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using LLama.Native; | |||
| namespace LLama.Sampling; | |||
| /// <summary> | |||
| /// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process. | |||
| /// </summary> | |||
| public interface ISamplingPipeline | |||
| : IDisposable | |||
| { | |||
| /// <summary> | |||
| /// Sample a single token from the given logits | |||
| /// </summary> | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A span of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens); | |||
| /// <summary> | |||
| /// Reset all internal state of the sampling pipeline | |||
| /// </summary> | |||
| void Reset(); | |||
| } | |||
| /// <summary> | |||
| /// Extensions methods for ISamplingPipeline | |||
| /// </summary> | |||
| public static class ISamplingPipelineExtensions | |||
| { | |||
| /// <summary> | |||
| /// Sample a single token from the given logits | |||
| /// </summary> | |||
| /// <param name="pipeline"></param> | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens) | |||
| { | |||
| #if NET5_0_OR_GREATER | |||
| var span = CollectionsMarshal.AsSpan(lastTokens); | |||
| return pipeline.Sample(ctx, logits, span); | |||
| #else | |||
| var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count); | |||
| try | |||
| { | |||
| lastTokens.CopyTo(copy); | |||
| return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count)); | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<int>.Shared.Return(copy); | |||
| } | |||
| #endif | |||
| } | |||
| } | |||
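The extension exists so executors can pass their `List<int>` token history without allocating: on .NET 5+ it reads the list's backing array in place via `CollectionsMarshal.AsSpan`, and on older targets it copies into a pooled buffer sized to the list. Usage is identical either way; `handle` here is an assumed `SafeLLamaContextHandle`:

```csharp
// Sketch: pipeline is any ISamplingPipeline.
var lastTokens = new List<int> { /* recent token ids */ };
var next = pipeline.Sample(handle, handle.GetLogits(), lastTokens);
```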