diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml index ec5f725e..a8169be1 100644 --- a/.github/workflows/compile.yml +++ b/.github/workflows/compile.yml @@ -140,7 +140,7 @@ jobs: - build: 'arm64' defines: '-DCMAKE_OSX_ARCHITECTURES=arm64' - build: 'x64' - defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF' + defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF -DLLAMA_AVX=ON -DLLAMA_AVX2=ON' runs-on: macos-latest steps: - uses: actions/checkout@v3 diff --git a/LLama.Examples/Assets/chat-with-bob.json b/LLama.Examples/Assets/chat-with-bob.json new file mode 100644 index 00000000..52dc3910 --- /dev/null +++ b/LLama.Examples/Assets/chat-with-bob.json @@ -0,0 +1,24 @@ +{ + "messages": [ + { + "author_role": "System", + "content": "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision." + }, + { + "author_role": "User", + "content": "Hello, Bob." + }, + { + "author_role": "Assistant", + "content": "Hello. How may I help you today?" + }, + { + "author_role": "User", + "content": "Please tell me the largest city in Europe." + }, + { + "author_role": "Assistant", + "content": "Sure. The largest city in Europe is Istanbul, Turkey." + } + ] +} diff --git a/LLama.Examples/Assets/chat-with-kunkun-chinese.json b/LLama.Examples/Assets/chat-with-kunkun-chinese.json new file mode 100644 index 00000000..cae03029 --- /dev/null +++ b/LLama.Examples/Assets/chat-with-kunkun-chinese.json @@ -0,0 +1,24 @@ +{ + "messages": [ + { + "author_role": "System", + "content": "下面是一段你和用户的对话,你叫坤坤,是一个在各方面都拥有丰富经验的助理,你非常乐于回答用户的问题和帮助用户。" + }, + { + "author_role": "User", + "content": "你好,坤坤。" + }, + { + "author_role": "Assistant", + "content": "你好,有什么我能帮助你的吗?" + }, + { + "author_role": "User", + "content": "中国的首都是哪座城市?" + }, + { + "author_role": "Assistant", + "content": "中国的首都是北京市。" + } + ] +} diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index ff27b962..3a9fe6c7 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -1,69 +1,122 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; +using System.Text; using LLama.Common; -namespace LLama.Examples.Examples +namespace LLama.Examples.Examples; + +public class ChatChineseGB2312 { - public class ChatChineseGB2312 + private static string ConvertEncoding(string input, Encoding original, Encoding target) + { + byte[] bytes = original.GetBytes(input); + var convertedBytes = Encoding.Convert(original, target, bytes); + return target.GetString(convertedBytes); + } + + public static async Task Run() { - private static string ConvertFromEncodingToAnother(string input, Encoding original, Encoding target) + // Register provider for GB2312 encoding + Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common in windows. 
It's recommended" + + " to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers."); + Console.ForegroundColor = ConsoleColor.White; + + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5, + Encoding = Encoding.UTF8 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + ChatSession session; + if (Directory.Exists("Assets/chat-with-kunkun-chinese")) + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Loading session from disk."); + Console.ForegroundColor = ConsoleColor.White; + + session = new ChatSession(executor); + session.LoadSession("Assets/chat-with-kunkun-chinese"); + } + else { - byte[] bytes = original.GetBytes(input); - var convertedBytes = Encoding.Convert(original, target, bytes); - return target.GetString(convertedBytes); + var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + + session = new ChatSession(executor, chatHistory); } - public static async Task Run() + session + .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤")); + + InferenceParams inferenceParams = new InferenceParams() { - Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); // Register gb2312 encoding - Console.Write("Please input your model path: "); - var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-kunkun-chinese.txt", encoding: Encoding.GetEncoding("gb2312")).Trim(); - prompt = ConvertFromEncodingToAnother(prompt, Encoding.GetEncoding("gb2312"), Encoding.UTF8); + Temperature = 0.9f, + AntiPrompts = new List { "用户:" } + }; - var parameters = new ModelParams(modelPath) - { - ContextSize = 1024, - Seed = 1337, - GpuLayerCount = 20, - Encoding = Encoding.UTF8 - }; - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - var executor = new InteractiveExecutor(context); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); - var session = new ChatSession(executor).WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户")); + // show the prompt + Console.ForegroundColor = ConsoleColor.White; + Console.Write("用户:"); + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? ""; - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common in windows. It's recommended" + - " to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers."); - Console.ForegroundColor = ConsoleColor.White; + while (userInput != "exit") + { + // Convert the encoding from gb2312 to utf8 for the language model + // and later saving to the history json file. 
+ userInput = ConvertEncoding(userInput, Encoding.GetEncoding("gb2312"), Encoding.UTF8); - // show the prompt - Console.Write(prompt); - while (true) + if (userInput == "save") { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() + session.SaveSession("Assets/chat-with-kunkun-chinese"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session saved."); + } + else if (userInput == "regenerate") + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Regenerating last response ..."); + + await foreach ( + var text + in session.RegenerateAssistantMessageAsync( + inferenceParams)) { - Temperature = 0.3f, - TopK = 5, - TopP = 0.85f, - AntiPrompts = new List { "用户:" }, - MaxTokens = 2048, - RepeatPenalty = 1.05f - })) + Console.ForegroundColor = ConsoleColor.White; + + // Convert the encoding from utf8 to gb2312 for the console output. + Console.Write(ConvertEncoding(text, Encoding.UTF8, Encoding.GetEncoding("gb2312"))); + } + } + else + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) { - //Console.Write(text); - Console.Write(ConvertFromEncodingToAnother(text, Encoding.UTF8, Encoding.GetEncoding("gb2312"))); + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); } - - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; } } } diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index 41362c4a..1246db59 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -1,44 +1,61 @@ using LLama.Common; -namespace LLama.Examples.Examples +namespace LLama.Examples.Examples; + +public class ChatSessionStripRoleName { - public class ChatSessionStripRoleName + public static async Task Run() { - public static async Task Run() + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) { - Console.Write("Please input your model path: "); - var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); - var parameters = new ModelParams(modelPath) - { - ContextSize = 1024, - Seed = 1337, - GpuLayerCount = 5 - }; - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - var executor = new InteractiveExecutor(context); - - var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started. The role names won't be printed."); - Console.ForegroundColor = ConsoleColor.White; + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? 
new ChatHistory(); - // show the prompt - Console.Write(prompt); - while (true) - { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) - { - Console.Write(text); - } + ChatSession session = new(executor, chatHistory); + session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + new string[] { "User:", "Assistant:" }, + redundancyLength: 8)); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? ""; + + while (userInput != "exit") + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; } } } diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs new file mode 100644 index 00000000..98ba7d75 --- /dev/null +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -0,0 +1,98 @@ +using LLama.Common; + +namespace LLama.Examples.Examples; + +public class ChatSessionWithHistory +{ + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + ChatSession session; + if (Directory.Exists("Assets/chat-with-bob")) + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Loading session from disk."); + Console.ForegroundColor = ConsoleColor.White; + + session = new ChatSession(executor); + session.LoadSession("Assets/chat-with-bob"); + } + else + { + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + + session = new ChatSession(executor, chatHistory); + } + + session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + new string[] { "User:", "Assistant:" }, + redundancyLength: 8)); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? 
""; + + while (userInput != "exit") + { + if (userInput == "save") + { + session.SaveSession("Assets/chat-with-bob"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session saved."); + } + else if (userInput == "regenerate") + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Regenerating last response ..."); + + await foreach ( + var text + in session.RegenerateAssistantMessageAsync( + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + } + else + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; + } + } +} diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index c9ea9023..d6b0d98e 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -1,44 +1,58 @@ using LLama.Common; -namespace LLama.Examples.Examples +namespace LLama.Examples.Examples; + +public class ChatSessionWithRoleName { - public class ChatSessionWithRoleName + public static async Task Run() { - public static async Task Run() + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) { - Console.Write("Please input your model path: "); - var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); - var parameters = new ModelParams(modelPath) - { - ContextSize = 1024, - Seed = 1337, - GpuLayerCount = 5 - }; - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - var executor = new InteractiveExecutor(context); - - var session = new ChatSession(executor); - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); - Console.ForegroundColor = ConsoleColor.White; + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); - // show the prompt - Console.Write(prompt); - while (true) - { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) - { - Console.Write(text); - } + ChatSession session = new(executor, chatHistory); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? 
""; + + while (userInput != "exit") + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; } } } diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs index 91068091..678d3eb9 100644 --- a/LLama.Examples/Examples/LoadAndSaveSession.cs +++ b/LLama.Examples/Examples/LoadAndSaveSession.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using DocumentFormat.OpenXml.Bibliography; +using LLama.Common; namespace LLama.Examples.Examples { @@ -30,7 +31,15 @@ namespace LLama.Examples.Examples Console.Write(prompt); while (true) { - await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, prompt), + new InferenceParams() + { + Temperature = 0.6f, + AntiPrompts = new List { "User:" } + })) { Console.Write(text); } diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs index 0ccce20e..3d9858e1 100644 --- a/LLama.Examples/Examples/Runner.cs +++ b/LLama.Examples/Examples/Runner.cs @@ -6,8 +6,10 @@ public class Runner { private static readonly Dictionary> Examples = new() { + { "Run a chat session with history.", ChatSessionWithHistory.Run }, { "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run }, { "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run }, + { "Run a chat session in Chinese GB2312 encoding", ChatChineseGB2312.Run }, { "Interactive mode chat by using executor.", InteractiveModeExecute.Run }, { "Instruct mode chat by using executor.", InstructModeExecute.Run }, { "Stateless mode chat by using executor.", StatelessModeExecute.Run }, @@ -23,7 +25,6 @@ public class Runner { "Coding Assistant.", CodingAssistant.Run }, { "Batch Decoding.", BatchedDecoding.Run }, { "SK Kernel Memory.", KernelMemory.Run }, - { "Chinese gb2312 chat", ChatChineseGB2312.Run }, { "Exit", async () => Environment.Exit(0) } }; diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index 2266bdcf..94704e01 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -2,7 +2,7 @@ Exe - net6.0 + net6.0;net7.0;net8.0 enable enable AnyCPU;x64 @@ -27,6 +27,12 @@ + + PreserveNewest + + + PreserveNewest + PreserveNewest diff --git a/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj b/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj index bf3280a3..a9bb5073 100644 --- a/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj +++ b/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj @@ -1,7 +1,7 @@ - net6.0;net7.0 + net6.0;net7.0;net8.0 enable enable 0.8.0 diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs index 9ad77531..389563aa 100644 --- a/LLama.Unittest/GrammarParserTest.cs +++ b/LLama.Unittest/GrammarParserTest.cs @@ -1,5 +1,4 @@ -using System.Text; -using LLama.Exceptions; +using LLama.Exceptions; using LLama.Native; using LLama.Grammars; diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index fbaee5ed..a206a23d 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj 
@@ -1,7 +1,7 @@ - + - net6.0 + net8.0 LLama.Unittest enable AnyCPU;x64 @@ -15,8 +15,8 @@ - - + + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs index aec4b5a3..34c9c21b 100644 --- a/LLama.Unittest/ModelsParamsTests.cs +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -1,4 +1,5 @@ using LLama.Common; +using System.Text.Json; namespace LLama.Unittest { @@ -16,14 +17,19 @@ namespace LLama.Unittest TensorSplits = { [0] = 3 } }; - var json = System.Text.Json.JsonSerializer.Serialize(expected); - var actual = System.Text.Json.JsonSerializer.Deserialize(json)!; + var json = JsonSerializer.Serialize(expected); + var actual = JsonSerializer.Deserialize(json)!; // Cannot compare splits with default equality, check they are sequence equal and then set to null Assert.Equal((IEnumerable)expected.TensorSplits, expected.TensorSplits); actual.TensorSplits = null!; expected.TensorSplits = null!; + // Check encoding is the same + var b1 = expected.Encoding.GetBytes("Hello"); + var b2 = actual.Encoding.GetBytes("Hello"); + Assert.True(b1.SequenceEqual(b2)); + Assert.Equal(expected, actual); } diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 195cc4a2..72e9acf8 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,5 +1,6 @@ using System.Diagnostics; using LLama.Common; +using LLama.Sampling; using Xunit.Abstractions; namespace LLama.Unittest @@ -30,10 +31,13 @@ namespace LLama.Unittest [Fact] public async Task Stateless() { + // Create a custom pipeline that mimics the default pipeline + var pipeline = new DefaultSamplingPipeline(); + var executor = new StatelessExecutor(_weights, _params); const string question = "Question. what is a cat?\nAnswer: "; - var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline }; var timer = new Stopwatch(); timer.Start(); diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs index 89d94ade..c604dc0d 100644 --- a/LLama.Web/Common/InferenceOptions.cs +++ b/LLama.Web/Common/InferenceOptions.cs @@ -1,6 +1,9 @@ -using LLama.Common; +#nullable enable + +using LLama.Common; using LLama.Abstractions; using LLama.Native; +using LLama.Sampling; namespace LLama.Web.Common { @@ -64,6 +67,9 @@ namespace LLama.Web.Common /// /// A grammar to constrain possible tokens /// - public SafeLLamaGrammarHandle Grammar { get; set; } = null; + public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + public ISamplingPipeline? SamplingPipeline { get; set; } } } diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index f1eb3538..f45c98ee 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -11,8 +11,7 @@ public class StatefulChatService : IDisposable private readonly LLamaContext _context; private bool _continue = false; - private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" - + "User: "; + private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. 
Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision."; public StatefulChatService(IConfiguration configuration) { @@ -25,7 +24,9 @@ public class StatefulChatService : IDisposable using var weights = LLamaWeights.LoadFromFile(@params); _context = new LLamaContext(weights, @params); + _session = new ChatSession(new InteractiveExecutor(_context)); + _session.History.AddMessage(Common.AuthorRole.System, SystemPrompt); } public void Dispose() @@ -35,10 +36,8 @@ public class StatefulChatService : IDisposable public async Task Send(SendMessageInput input) { - var userInput = input.Text; if (!_continue) { - userInput = SystemPrompt + userInput; Console.Write(SystemPrompt); _continue = true; } @@ -47,11 +46,14 @@ public class StatefulChatService : IDisposable Console.Write(input.Text); Console.ForegroundColor = ConsoleColor.White; - var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() - { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, - }); + var outputs = _session.ChatAsync( + new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text), + new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + var result = ""; await foreach (var output in outputs) { @@ -64,10 +66,8 @@ public class StatefulChatService : IDisposable public async IAsyncEnumerable SendStream(SendMessageInput input) { - var userInput = input.Text; if (!_continue) { - userInput = SystemPrompt + userInput; Console.Write(SystemPrompt); _continue = true; } @@ -76,11 +76,14 @@ public class StatefulChatService : IDisposable Console.Write(input.Text); Console.ForegroundColor = ConsoleColor.White; - var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() - { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, - }); + var outputs = _session.ChatAsync( + new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text) + , new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + await foreach (var output in outputs) { Console.Write(output); diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index d87faf0e..e1e89414 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using LLama.Common; using LLama.Native; +using LLama.Sampling; namespace LLama.Abstractions { @@ -108,5 +109,10 @@ namespace LLama.Abstractions /// Grammar to constrain possible tokens /// SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + /// Set a custom sampling pipeline to use. If this is set All other sampling parameters are ignored! + /// + ISamplingPipeline? 
SamplingPipeline { get; set; } } } \ No newline at end of file diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index 2ecfe49c..4a3dde7a 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -3,6 +3,9 @@ using System.Buffers; using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; +using LLama.Common; using LLama.Native; namespace LLama.Abstractions @@ -105,6 +108,7 @@ namespace LLama.Abstractions /// /// A fixed size array to set the tensor splits across multiple GPUs /// + [JsonConverter(typeof(TensorSplitsCollectionConverter))] public sealed class TensorSplitsCollection : IEnumerable { @@ -174,4 +178,24 @@ namespace LLama.Abstractions } #endregion } + + /// + /// A JSON converter for + /// + public class TensorSplitsCollectionConverter + : JsonConverter + { + /// + public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var arr = JsonSerializer.Deserialize(ref reader, options) ?? Array.Empty(); + return new TensorSplitsCollection(arr); + } + + /// + public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options) + { + JsonSerializer.Serialize(writer, value.Splits, options); + } + } } \ No newline at end of file diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 5c535a6b..2985bd5f 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -1,246 +1,496 @@ -using LLama.Abstractions; -using LLama.Common; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.CompilerServices; -using System.Text; using System.Threading; using System.Threading.Tasks; +using LLama.Abstractions; +using LLama.Common; using static LLama.InteractiveExecutor; -namespace LLama +namespace LLama; + +/// +/// The main chat session class. +/// +public class ChatSession { + private const string _modelStateFilename = "ModelState.st"; + private const string _executorStateFilename = "ExecutorState.json"; + private const string _hsitoryFilename = "ChatHistory.json"; + /// - /// The main chat session class. - /// - public class ChatSession - { - private readonly ILLamaExecutor _executor; - private readonly ChatHistory _history; - - private const string _executorStateFilename = "ExecutorState.json"; - private const string _modelStateFilename = "ModelState.st"; - - /// - /// The executor for this session. - /// - public ILLamaExecutor Executor => _executor; - /// - /// The chat history for this session. - /// - public ChatHistory History => _history; - /// - /// The history transform used in this session. - /// - public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); - /// - /// The input transform pipeline used in this session. - /// - public List InputTransformPipeline { get; set; } = new(); - /// - /// The output transform used in this session. - /// - public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform(); - - /// - /// - /// - /// The executor for this session - public ChatSession(ILLamaExecutor executor) - { - _executor = executor; - _history = new ChatHistory(); - } - - /// - /// Use a custom history transform. 
- /// - /// - /// - public ChatSession WithHistoryTransform(IHistoryTransform transform) - { - HistoryTransform = transform; - return this; - } - - /// - /// Add a text transform to the input transform pipeline. - /// - /// - /// - public ChatSession AddInputTransform(ITextTransform transform) - { - InputTransformPipeline.Add(transform); - return this; - } - - /// - /// Use a custom output transform. - /// - /// - /// - public ChatSession WithOutputTransform(ITextStreamTransform transform) - { - OutputTransform = transform; - return this; - } - - /// - /// - /// - /// The directory name to save the session. If the directory does not exist, a new directory will be created. - public virtual void SaveSession(string path) - { - if (!Directory.Exists(path)) - { - Directory.CreateDirectory(path); - } - _executor.Context.SaveState(Path.Combine(path, _modelStateFilename)); - if (Executor is StatelessExecutor) - { + /// The executor for this session. + /// + public ILLamaExecutor Executor { get; private set; } - } - else if (Executor is StatefulExecutorBase statefulExecutor) - { - statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename)); - } - else - { - throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method."); - } + /// + /// The chat history for this session. + /// + public ChatHistory History { get; private set; } = new(); + + /// + /// The history transform used in this session. + /// + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); + + /// + /// The input transform pipeline used in this session. + /// + public List InputTransformPipeline { get; set; } = new(); + + /// + /// The output transform used in this session. + /// + public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform(); + + /// + /// Create a new chat session. + /// + /// The executor for this session + public ChatSession(ILLamaExecutor executor) + { + // Check if executor has StatefulExecutorBase as base class + if (executor is not StatefulExecutorBase) + { + throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); } - /// - /// - /// - /// The directory name to load the session. - public virtual void LoadSession(string path) + Executor = executor; + } + + /// + /// Create a new chat session with a custom history. + /// + /// + /// + public ChatSession(ILLamaExecutor executor, ChatHistory history) + : this(executor) + { + History = history; + } + + /// + /// Use a custom history transform. + /// + /// + /// + public ChatSession WithHistoryTransform(IHistoryTransform transform) + { + HistoryTransform = transform; + return this; + } + + /// + /// Add a text transform to the input transform pipeline. + /// + /// + /// + public ChatSession AddInputTransform(ITextTransform transform) + { + InputTransformPipeline.Add(transform); + return this; + } + + /// + /// Use a custom output transform. + /// + /// + /// + public ChatSession WithOutputTransform(ITextStreamTransform transform) + { + OutputTransform = transform; + return this; + } + + /// + /// Save a session from a directory. 
+    ///
+    ///
+    ///
+    public void SaveSession(string path)
+    {
+        if (string.IsNullOrWhiteSpace(path))
         {
-            if (!Directory.Exists(path))
-            {
-                throw new FileNotFoundException($"Directory {path} does not exist.");
-            }
-            _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
-            if (Executor is StatelessExecutor)
-            {
+            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+        }
-            }
-            else if (Executor is StatefulExecutorBase statefulExecutor)
-            {
-                statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
-            }
-            else
-            {
-                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
-            }
+        if (Directory.Exists(path))
+        {
+            Directory.Delete(path, recursive: true);
         }
-        ///
-        /// Generates a response for a given user prompt and manages history state for the user.
-        /// This will always pass the whole history to the model. Don't pass a whole history
-        /// to this method as the user prompt will be appended to the history of the current session.
-        /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
-        ///
-        ///
-        ///
-        ///
-        /// Returns generated text of the assistant message.
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        Directory.CreateDirectory(path);
+
+        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
+        Executor.Context.SaveState(modelStateFilePath);
+
+        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
+        ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);
+
+        string historyFilepath = Path.Combine(path, _hsitoryFilename);
+        File.WriteAllText(historyFilepath, History.ToJson());
+    }
+
+    ///
+    /// Load a session from a directory.
+    ///
+    ///
+    ///
+    public void LoadSession(string path)
+    {
+        if (string.IsNullOrWhiteSpace(path))
         {
-            foreach (var inputTransform in InputTransformPipeline)
-                prompt = inputTransform.Transform(prompt);
+            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+        }
+
+        if (!Directory.Exists(path))
+        {
+            throw new ArgumentException("Directory does not exist", nameof(path));
+        }
+
+        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
+        Executor.Context.LoadState(modelStateFilePath);
-            // TODO: need to be refactored.
-            if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun)
+        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
+        ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);
+
+        string historyFilepath = Path.Combine(path, _hsitoryFilename);
+        string historyJson = File.ReadAllText(historyFilepath);
+        History = ChatHistory.FromJson(historyJson)
+            ?? throw new ArgumentException("History file is invalid", nameof(path));
+    }
+
+    ///
+    /// Add a message to the chat history.
+    ///
+    ///
+    ///
+    public ChatSession AddMessage(ChatHistory.Message message)
+    {
+        // If current message is a system message, only allow the history to be empty
+        if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
+        {
+            throw new ArgumentException("Cannot add a system message after another message", nameof(message));
+        }
+
+        // If current message is a user message, only allow the history to be empty,
+        // or the previous message to be a system message or assistant message.
+        if (message.AuthorRole == AuthorRole.User)
+        {
+            ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
+            if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
             {
-                History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt));
-                var converted_prompt = HistoryTransform.HistoryToText(History);
-                // Avoid missing anti-prompt.
-                if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n"))
-                {
-                    prompt = converted_prompt.Trim();
-                }
-                else
-                {
-                    prompt = converted_prompt;
-                }
+                throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
             }
-            else
+        }
+
+        // If the current message is an assistant message,
+        // the previous message must be a user message.
+        if (message.AuthorRole == AuthorRole.Assistant)
+        {
+            ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
+            if (lastMessage is null
+                || lastMessage.AuthorRole != AuthorRole.User)
             {
-                History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+                throw new ArgumentException("Assistant message must be preceded by a user message", nameof(message));
             }
+        }
+
+        History.AddMessage(message.AuthorRole, message.Content);
+        return this;
+    }
+
+    ///
+    /// Add a system message to the chat history.
+    ///
+    ///
+    ///
+    public ChatSession AddSystemMessage(string content)
+        => AddMessage(new ChatHistory.Message(AuthorRole.System, content));
+
+    ///
+    /// Add an assistant message to the chat history.
+    ///
+    ///
+    ///
+    public ChatSession AddAssistantMessage(string content)
+        => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+
+    ///
+    /// Add a user message to the chat history.
+    ///
+    ///
+    ///
+    public ChatSession AddUserMessage(string content)
+        => AddMessage(new ChatHistory.Message(AuthorRole.User, content));
-            StringBuilder sb = new();
+    ///
+    /// Remove the last message from the chat history.
+    ///
+    ///
+    public ChatSession RemoveLastMessage()
+    {
+        History.Messages.RemoveAt(History.Messages.Count - 1);
+        return this;
+    }
+
+    ///
+    /// Replace a user message with a new message and remove all messages after the new message.
+    /// This is useful when the user wants to edit a message and regenerate the response.
+    ///
+    ///
+    ///
+    ///
+    public ChatSession ReplaceUserMessage(
+        ChatHistory.Message oldMessage,
+        ChatHistory.Message newMessage)
+    {
+        if (oldMessage.AuthorRole != AuthorRole.User)
+        {
+            throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
+        }
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
+        if (newMessage.AuthorRole != AuthorRole.User)
+        {
+            throw new ArgumentException("New message must be a user message", nameof(newMessage));
+        }
+
+        int index = History.Messages.IndexOf(oldMessage);
+        if (index == -1)
+        {
+            throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
+        }
+
+        History.Messages[index] = newMessage;
+
+        // Remove all messages after the new message
+        History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);
+
+        return this;
+    }
+
+    ///
+    /// Chat with the model.
+    ///
+    ///
+    ///
+    ///
+    ///
+    ///
+    ///
+    public async IAsyncEnumerable<string> ChatAsync(
+        ChatHistory.Message message,
+        bool applyInputTransformPipeline,
+        IInferenceParams? inferenceParams = null,
+        [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        // The message must be a user message
+        if (message.AuthorRole != AuthorRole.User)
+        {
+            throw new ArgumentException("Message must be a user message", nameof(message));
+        }
+
+        // Apply input transform pipeline
+        if (applyInputTransformPipeline)
+        {
+            foreach (var inputTransform in InputTransformPipeline)
             {
-                yield return result;
-                sb.Append(result);
+                message.Content = inputTransform.Transform(message.Content);
             }
+        }
+
+        // Add the user's message to the history
+        AddUserMessage(message.Content);
+
+        // Prepare prompt variable
+        string prompt;
+
+        // Check if the session history was restored from a previous session
+        // or added as part of new chat session history.
+        InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();
+
+        // If "IsPromptRun" is true, the session was newly started.
+        if (state.IsPromptRun)
+        {
+            // If the session history was added as part of new chat session history,
+            // convert the complete history including system message and manually added history
+            // to a prompt that adheres to the prompt template specified in the HistoryTransform class implementation.
+            prompt = HistoryTransform.HistoryToText(History);
+        }
+        else
+        {
+            // If the session was restored from a previous session,
+            // convert only the current message to the prompt with the prompt template
+            // specified in the HistoryTransform class implementation that is provided.
+            ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
+            prompt = HistoryTransform.HistoryToText(singleMessageHistory);
+        }
+
+        string assistantMessage = string.Empty;
+
+        await foreach (
+            string textToken
+            in ChatAsyncInternal(
+                prompt,
+                inferenceParams,
+                cancellationToken))
+        {
+            assistantMessage += textToken;
+            yield return textToken;
+        }
+
+        // Add the assistant message to the history
+        AddAssistantMessage(assistantMessage);
+    }
+
+    ///
+    /// Chat with the model.
+    ///
+    ///
+    ///
+    ///
+    ///
+    public IAsyncEnumerable<string> ChatAsync(
+        ChatHistory.Message message,
+        IInferenceParams? inferenceParams = null,
+        CancellationToken cancellationToken = default)
+    {
+        return ChatAsync(
+            message,
+            applyInputTransformPipeline: true,
+            inferenceParams,
+            cancellationToken);
+    }
-            string assistantMessage = sb.ToString();
+    ///
+    /// Chat with the model.
+    ///
+    ///
+    ///
+    ///
+    ///
+    ///
+    ///
+    public IAsyncEnumerable<string> ChatAsync(
+        ChatHistory history,
+        bool applyInputTransformPipeline,
+        IInferenceParams? inferenceParams = null,
+        CancellationToken cancellationToken = default)
+    {
+        ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
+            ?? throw new ArgumentException("History must contain at least one message", nameof(history));
-            // Remove end tokens from the assistant message
-            // if defined in inferenceParams.AntiPrompts.
-            // We only want the response that was generated and not tokens
-            // that are delimiting the beginning or end of the response.
-            if (inferenceParams?.AntiPrompts != null)
+        foreach (
+            ChatHistory.Message message
+            in history.Messages.Take(history.Messages.Count - 1))
+        {
+            // Apply input transform pipeline
+            if (applyInputTransformPipeline
+                && message.AuthorRole == AuthorRole.User)
             {
-                foreach (var stopToken in inferenceParams.AntiPrompts)
+                foreach (
+                    var inputTransform
+                    in InputTransformPipeline)
                 {
-                    assistantMessage = assistantMessage.Replace(stopToken, "");
+                    message.Content = inputTransform.Transform(message.Content);
                 }
             }
-            History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
+            AddMessage(message);
         }
-        ///
-        /// Generates a response for a given chat history. This method does not manage history state for the user.
-        /// If you want to e.g. truncate the history of a session to fit into the model's context window,
-        /// use this method and pass the truncated history to it. If you don't need this control, use the other
-        /// overload of this method that accepts a user prompt instead.
-        ///
-        ///
-        ///
-        ///
-        /// Returns generated text of the assistant message.
-        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        return ChatAsync(
+            lastMessage,
+            applyInputTransformPipeline,
+            inferenceParams,
+            cancellationToken);
+    }
+
+    ///
+    /// Chat with the model.
+    ///
+    ///
+    ///
+    ///
+    ///
+    public IAsyncEnumerable<string> ChatAsync(
+        ChatHistory history,
+        IInferenceParams? inferenceParams = null,
+        CancellationToken cancellationToken = default)
+    {
+        return ChatAsync(
+            history,
+            applyInputTransformPipeline: true,
+            inferenceParams,
+            cancellationToken);
+    }
+
+    ///
+    /// Regenerate the last assistant message.
+    ///
+    ///
+    ///
+    ///
+    ///
+    public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
+        InferenceParams? inferenceParams = null,
+        [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        // Make sure the last message is an assistant message (response from the LLM).
+        ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();
+
+        if (lastAssistantMessage is null
+            || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
         {
-            if (history.Messages.Count == 0)
-            {
-                throw new ArgumentException("History must contain at least one message.");
-            }
+            throw new InvalidOperationException("Last message must be an assistant message");
+        }
-            string prompt;
-            if (_executor is InteractiveExecutor executor)
-            {
-                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+        // Remove the last assistant message from the history.
+        RemoveLastMessage();
-                prompt = state.IsPromptRun
-                    ? HistoryTransform.HistoryToText(History)
-                    : history.Messages.Last().Content;
-            }
-            else
-            {
-                prompt = history.Messages.Last().Content;
-            }
+        // Get the last user message.
+        ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
+        if (lastUserMessage is null
+            || lastUserMessage.AuthorRole != AuthorRole.User)
+        {
+            throw new InvalidOperationException("Last message must be a user message");
         }
-        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        // Remove the last user message from the history.
+ RemoveLastMessage(); + + // Regenerate the assistant message. + await foreach ( + string textToken + in ChatAsync( + lastUserMessage, + applyInputTransformPipeline: false, + inferenceParams, + cancellationToken)) { - var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); - await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) - { - yield return item; - } + yield return textToken; + } + } + + private async IAsyncEnumerable ChatAsyncInternal( + string prompt, + IInferenceParams? inferenceParams = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken); + + await foreach ( + string textToken + in OutputTransform + .TransformAsync(results) + .WithCancellation(cancellationToken)) + { + yield return textToken; } } -} \ No newline at end of file +} diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index 7224b314..3f038874 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -1,4 +1,7 @@ using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using System.Text.Json.Serialization; namespace LLama.Common { @@ -43,11 +46,14 @@ namespace LLama.Common /// /// Role of the message author, e.g. user/assistant/system /// + [JsonConverter(typeof(JsonStringEnumConverter))] + [JsonPropertyName("author_role")] public AuthorRole AuthorRole { get; set; } /// /// Message content /// + [JsonPropertyName("content")] public string Content { get; set; } /// @@ -65,15 +71,14 @@ namespace LLama.Common /// /// List of messages in the chat /// - public List Messages { get; } + [JsonPropertyName("messages")] + public List Messages { get; set; } = new(); /// /// Create a new instance of the chat content class /// - public ChatHistory() - { - this.Messages = new List(); - } + [JsonConstructor] + public ChatHistory() { } /// /// Add a message to the chat history @@ -84,6 +89,29 @@ namespace LLama.Common { this.Messages.Add(new Message(authorRole, content)); } - } + /// + /// Serialize the chat history to JSON + /// + /// + public string ToJson() + { + return JsonSerializer.Serialize( + this, + new JsonSerializerOptions() + { + WriteIndented = true + }); + } + + /// + /// Deserialize a chat history from JSON + /// + /// + /// + public static ChatHistory? FromJson(string json) + { + return JsonSerializer.Deserialize(json); + } + } } diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index d7bd19d9..c1f39550 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using LLama.Native; +using LLama.Sampling; namespace LLama.Common { @@ -76,6 +77,9 @@ namespace LLama.Common /// public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + public ISamplingPipeline? 
SamplingPipeline { get; set; } } /// diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index f1cef072..cecd655a 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -59,7 +59,6 @@ namespace LLama.Common public bool EmbeddingMode { get; set; } /// - [JsonConverter(typeof(TensorSplitsCollectionConverter))] public TensorSplitsCollection TensorSplits { get; set; } = new(); /// @@ -92,9 +91,20 @@ namespace LLama.Common /// public bool VocabOnly { get; set; } + /// + /// `Encoding` cannot be directly JSON serialized, instead store the name as a string which can + /// + [JsonPropertyName("Encoding")] + [JsonInclude] + private string EncodingName { get; set; } = Encoding.UTF8.WebName; + /// - [JsonConverter(typeof(EncodingConverter))] - public Encoding Encoding { get; set; } = Encoding.UTF8; + [JsonIgnore] + public Encoding Encoding + { + get => Encoding.GetEncoding(EncodingName); + set => EncodingName = value.WebName; + } /// /// @@ -112,36 +122,4 @@ namespace LLama.Common ModelPath = ""; } } - - internal class EncodingConverter - : JsonConverter - { - public override Encoding? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - var name = reader.GetString(); - if (name == null) - return null; - return Encoding.GetEncoding(name); - } - - public override void Write(Utf8JsonWriter writer, Encoding value, JsonSerializerOptions options) - { - writer.WriteStringValue(value.WebName); - } - } - - internal class TensorSplitsCollectionConverter - : JsonConverter - { - public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - var arr = JsonSerializer.Deserialize(ref reader, options) ?? Array.Empty(); - return new TensorSplitsCollection(arr); - } - - public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options) - { - JsonSerializer.Serialize(writer, value.Splits, options); - } - } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 3a3e51af..2902dc8f 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -10,6 +10,7 @@ using LLama.Common; using System.Runtime.InteropServices; using LLama.Extensions; using LLama.Abstractions; +using LLama.Sampling; using Microsoft.Extensions.Logging; namespace LLama @@ -212,6 +213,17 @@ namespace LLama } } + /// + /// Sample a single token from this context, using the given sampling pipeline + /// + /// The pipeline to use to process the logits and to select a token + /// The tokens recently returned from the model + /// The selected token + public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan lastTokens) + { + return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); + } + /// /// Perform the sampling. Please don't use it unless you fully know what it does. 
/// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index d81630aa..3ed66890 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -210,16 +210,24 @@ namespace LLama SaveSessionFile(_pathSession); } - var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + } + else + { + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var mu = MirostatMu; - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; + var mu = MirostatMu; + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + MirostatMu = mu; + } _last_n_tokens.Enqueue(id); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 4d28274b..9cecf437 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -189,16 +189,24 @@ namespace LLama SaveSessionFile(_pathSession); } - var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - var mu = MirostatMu; - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + } + else + { + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + + var mu = MirostatMu; + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + MirostatMu = mu; + } _last_n_tokens.Enqueue(id); diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index 0e029c2d..5e7de5f4 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -28,6 +28,7 @@ AnyCPU;x64;Arm64 LLamaSharp 
    Debug;Release;GPU
+    false
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 9c41af7c..831aceb2 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;

 namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }

                 // Decode this token into text
                 decoder.Add(id);
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 4bc154f4..5059a5f3 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -46,14 +46,41 @@ namespace LLama.Native
             return new LLamaTokenDataArray(candidates);
         }

+        /// <summary>
+        /// Overwrite the logit values for all given tokens
+        /// </summary>
+        /// <param name="values">tuples of token and logit value to overwrite</param>
+        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+        {
+            if (values.Length == 0)
+                return;
+
+            var dataSpan = data.Span;
+            foreach (var (token, value) in values)
+            {
+                for (var i = 0; i < data.Length; i++)
+                {
+                    if (dataSpan[i].id == token)
+                    {
+                        dataSpan[i].logit = value;
+                        break;
+                    }
+                }
+            }
+            sorted = false;
+        }
+
         #region sampling
         /// <summary>
         /// Apply grammar rules to candidate tokens
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="grammar"></param>
-        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
         {
+            if (grammar == null)
+                return;
+
             using (LLamaTokenDataArrayNative.Create(this, out var st))
             {
                 NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@ namespace LLama.Native
         /// <param name="penalty_repeat"></param>
         /// <param name="penalty_freq"></param>
         /// <param name="penalty_present"></param>
-        public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+        public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
         {
             unsafe
             {
                 using (LLamaTokenDataArrayNative.Create(this, out var st))
-                using (var last_tokens_handle = last_tokens.Pin())
                 {
-                    NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
-                    sorted = st.sorted;
+                    fixed (int* last_tokens_handle = last_tokens)
+                    {
+                        NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+                        sorted = st.sorted;
+                    }
                 }
             }
         }
diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs
new file mode 100644
index 00000000..4c0f7689
--- /dev/null
+++ b/LLama/Sampling/BaseSamplingPipeline.cs
@@ -0,0 +1,128 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
+/// </summary>
+public abstract class BaseSamplingPipeline
+    : ISamplingPipeline
+{
+    private int _savedLogitsCount;
+    private (int index, float logit)[]? _savedLogits;
+
+    /// <inheritdoc />
+    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        var protectedLogits = GetProtectedTokens(ctx);
+        _savedLogitsCount = protectedLogits.Count;
+        _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
+        try
+        {
+            // Save the values of protected logits
+            for (var i = 0; i < protectedLogits.Count; i++)
+            {
+                var index = protectedLogits[i];
+                var value = logits[index];
+                _savedLogits[i] = (index, value);
+            }
+
+            // Process raw logits
+            ProcessLogits(ctx, logits, lastTokens);
+
+            // Automatically restore saved logit values after processing
+            RestoreProtectedTokens(logits);
+
+            // Convert logits into token candidates
+            var candidates = LLamaTokenDataArray.Create(logits);
+
+            // Process token data array
+            ProcessTokenDataArray(ctx, candidates, lastTokens);
+
+            // Choose the final value
+            return ChooseToken(ctx, candidates);
+        }
+        finally
+        {
+            ArrayPool<(int, float)>.Shared.Return(_savedLogits);
+            _savedLogits = null;
+            _savedLogitsCount = 0;
+        }
+    }
+
+    #region protected tokens
+    /// <summary>
+    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
+    /// </summary>
+    /// <returns></returns>
+    protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="logits"></param>
+    protected void RestoreProtectedTokens(Span<float> logits)
+    {
+        if (_savedLogits == null)
+            return;
+
+        // The array may be bigger than necessary, get a span of the valid bit
+        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
+
+        // Restore the values of protected logits
+        for (var i = 0; i < saved.Length; i++)
+            logits[saved[i].index] = saved[i].logit;
+    }
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="candidates"></param>
+    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
+    {
+        if (_savedLogits == null || _savedLogits.Length == 0)
+            return;
+
+        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
+    }
+    #endregion
+
+    /// <summary>
+    /// Process the raw logit values
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Process the LLamaTokenDataArray and select a single token
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Choose the final token from the candidates
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <param name="candidates"></param>
+    /// <returns></returns>
+    protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
+
+    /// <inheritdoc />
+    public virtual void Reset()
+    {
+    }
+
+    /// <inheritdoc />
+    public virtual void Dispose()
+    {
+        GC.SuppressFinalize(this);
+    }
+}
\ No newline at end of file
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
new file mode 100644
index 00000000..e6db2efe
--- /dev/null
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -0,0 +1,149 @@
+using System;
+using System.Collections.Generic;
+using LLama.Extensions;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
+/// </summary>
+public sealed class DefaultSamplingPipeline
+    : BaseSamplingPipeline
+{
+    /// <summary>
+    /// Bias values to add to certain logits
+    /// </summary>
+    public Dictionary<int, float> LogitBias { get; } = new();
+
+    /// <summary>
+    /// Grammar to constrain valid tokens
+    /// </summary>
+    public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+    /// <summary>
+    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
+    /// </summary>
+    public float RepeatPenalty { get; set; } = 1.1f;
+
+    /// <summary>
+    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+    /// </summary>
+    public float AlphaFrequency
+    {
+        get => _alphaFreq;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
+            _alphaFreq = value;
+        }
+    }
+    private float _alphaFreq = 0.1f;
+
+    /// <summary>
+    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+    /// text so far, increasing the model's likelihood to talk about new topics.
+    /// </summary>
+    public float AlphaPresence
+    {
+        get => _alphaPresence;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
+            _alphaPresence = value;
+        }
+    }
+    private float _alphaPresence = 0.1f;
+
+    /// <summary>
+    /// Temperature to apply (higher temperature is more "creative")
+    /// </summary>
+    public float Temperature { get; set; } = 0.75f;
+
+    /// <summary>
+    /// Number of tokens to keep in TopK sampling
+    /// </summary>
+    public int TopK { get; set; }
+
+    /// <summary>
+    /// Z value for tail free sampling
+    /// </summary>
+    public float TailFreeZ { get; set; }
+
+    /// <summary>
+    /// P value for locally typical sampling
+    /// </summary>
+    public float TypicalP { get; set; }
+
+    /// <summary>
+    /// P value for TopP sampling
+    /// </summary>
+    public float TopP { get; set; } = 1f;
+
+    /// <summary>
+    /// P value for MinP sampling
+    /// </summary>
+    public float MinP { get; set; }
+
+    /// <summary>
+    /// Whether the newline value should be protected from being modified by logit bias and repeat penalty
+    /// </summary>
+    public bool PenalizeNewline { get; set; } = false;
+
+    private readonly int[] _newlineToken = new int[1];
+
+    /// <inheritdoc />
+    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
+    {
+        if (PenalizeNewline)
+            return Array.Empty<int>();
+
+        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
+        return _newlineToken;
+    }
+
+    /// <inheritdoc />
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        foreach (var (key, value) in LogitBias)
+            logits[key] += value;
+    }
+
+    /// <inheritdoc />
+    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
+    {
+        // Apply penalties to candidates
+        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+
+        // Restore protected tokens, so they are not affected by repetition penalties
+        RestoreProtectedTokens(candidates);
+
+        // Apply the normal llama.cpp pipeline
+        candidates.ApplyGrammar(ctx, Grammar);
+        candidates.TopK(ctx, TopK);
+        candidates.TailFree(ctx, TailFreeZ);
+        candidates.LocallyTypical(ctx, TypicalP);
+        candidates.TopP(ctx, TopP);
+        candidates.MinP(ctx, MinP);
+        candidates.Temperature(ctx, Temperature);
+        var id = candidates.SampleToken(ctx);
+
+        Grammar?.AcceptToken(ctx, id);
+        return id;
+    }
+
+    /// <inheritdoc />
+    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+    {
+        return candidates.SampleToken(ctx);
+    }
+}
\ No newline at end of file
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
new file mode 100644
index 00000000..f39bf996
--- /dev/null
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -0,0 +1,61 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
+/// </summary>
+public interface ISamplingPipeline
+    : IDisposable
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
+    /// <returns></returns>
+    int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Reset all internal state of the sampling pipeline
+    /// </summary>
+    void Reset();
+}
+
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="pipeline"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+    {
+#if NET5_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(lastTokens);
+        return pipeline.Sample(ctx, logits, span);
+#else
+        var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+        try
+        {
+            lastTokens.CopyTo(copy);
+            return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+        }
+        finally
+        {
+            ArrayPool<int>.Shared.Return(copy);
+        }
+#endif
+    }
+}
\ No newline at end of file
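
Editor's note (illustrative, not part of the diff): the changes above route sampling through an optional `SamplingPipeline` on the inference parameters, falling back to the old `ApplyPenalty`/`Context.Sample` path when it is null. Below is a minimal sketch of how an application might opt into the new pipeline; the model path, prompt, and parameter values are placeholders, and the overall executor setup follows the existing LLamaSharp examples rather than anything introduced by this diff.

```csharp
using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;
using LLama.Sampling;

// Placeholder model path, for illustration only.
var parameters = new ModelParams("path/to/model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// When SamplingPipeline is set, the executors call it instead of the
// ApplyPenalty/Context.Sample branch (see the `is not null` checks above).
var inferenceParams = new InferenceParams
{
    SamplingPipeline = new DefaultSamplingPipeline
    {
        Temperature = 0.7f,
        TopK = 40,
        TopP = 0.95f,
        RepeatPenalty = 1.1f,
    },
    AntiPrompts = new List<string> { "User:" }
};

await foreach (var text in executor.InferAsync("User: Hello!\nBob:", inferenceParams))
    Console.Write(text);
```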
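A custom pipeline can likewise be built by deriving from `BaseSamplingPipeline` and overriding its four hooks. The sketch below is hypothetical: the class name and parameter choices are invented, and the span element types follow the reconstruction above. It mirrors the structure of `DefaultSamplingPipeline`, which also draws a token inside `ProcessTokenDataArray`; the base class discards that return value and takes the final token from `ChooseToken`.

```csharp
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

// Hypothetical example: top-k + temperature only, with no penalties and no protected tokens.
public sealed class TopKOnlyPipeline : BaseSamplingPipeline
{
    public int TopK { get; set; } = 40;
    public float Temperature { get; set; } = 0.8f;

    // Nothing is protected, so ProcessLogits may modify any logit.
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        => Array.Empty<int>();

    // This is where per-token logit bias would be applied; intentionally a no-op here.
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
    }

    // Shape the candidate distribution. The base class ignores this return value.
    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        candidates.TopK(ctx, TopK);
        candidates.Temperature(ctx, Temperature);
        return candidates.SampleToken(ctx);
    }

    // The token returned here is what the executor ultimately receives.
    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        => candidates.SampleToken(ctx);
}
```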