From ab8dd0dfc7604249b70cf73334245e953377949f Mon Sep 17 00:00:00 2001 From: Zoli Somogyi Date: Wed, 24 Apr 2024 08:06:40 +0200 Subject: [PATCH] Correcting non-standard way of working with PromptExecutionSettings The extension of PromptExecutionSettings is not only for ChatCompletion, but also for text completion and text embedding. --- .../Examples/SemanticKernelPrompt.cs | 4 +-- .../LLamaSharpChatCompletion.cs | 12 ++++---- .../ChatRequestSettings.cs | 30 +++++++++++++------ .../ChatRequestSettingsConverter.cs | 10 +++---- LLama.SemanticKernel/ExtensionMethods.cs | 7 ++--- .../LLamaSharpTextCompletion.cs | 5 ++-- .../ChatRequestSettingsConverterTests.cs | 15 +++++----- .../ChatRequestSettingsTests.cs | 16 +++++----- .../SemanticKernel/ExtensionMethodsTests.cs | 2 +- 9 files changed, 56 insertions(+), 45 deletions(-) rename LLama.SemanticKernel/{ChatCompletion => }/ChatRequestSettings.cs (76%) rename LLama.SemanticKernel/{ChatCompletion => }/ChatRequestSettingsConverter.cs (88%) diff --git a/LLama.Examples/Examples/SemanticKernelPrompt.cs b/LLama.Examples/Examples/SemanticKernelPrompt.cs index fdf58b3a..38002d3d 100644 --- a/LLama.Examples/Examples/SemanticKernelPrompt.cs +++ b/LLama.Examples/Examples/SemanticKernelPrompt.cs @@ -1,9 +1,9 @@ using LLama.Common; -using LLamaSharp.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel; using LLamaSharp.SemanticKernel.TextCompletion; using Microsoft.SemanticKernel.TextGeneration; using Microsoft.Extensions.DependencyInjection; +using LLamaSharp.SemanticKernel; namespace LLama.Examples.Examples { @@ -31,7 +31,7 @@ namespace LLama.Examples.Examples One line TLDR with the fewest words."; - ChatRequestSettings settings = new() { MaxTokens = 100 }; + LLamaSharpPromptExecutionSettings settings = new() { MaxTokens = 100 }; var summarize = kernel.CreateFunctionFromPrompt(prompt, settings); string text1 = @" diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs index 7bcbaf7b..26ecdccc 100644 --- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs +++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs @@ -17,7 +17,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion; public sealed class LLamaSharpChatCompletion : IChatCompletionService { private readonly ILLamaExecutor _model; - private ChatRequestSettings defaultRequestSettings; + private LLamaSharpPromptExecutionSettings defaultRequestSettings; private readonly IHistoryTransform historyTransform; private readonly ITextStreamTransform outputTransform; @@ -25,9 +25,9 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService public IReadOnlyDictionary Attributes => this._attributes; - static ChatRequestSettings GetDefaultSettings() + static LLamaSharpPromptExecutionSettings GetDefaultSettings() { - return new ChatRequestSettings + return new LLamaSharpPromptExecutionSettings { MaxTokens = 256, Temperature = 0, @@ -37,7 +37,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService } public LLamaSharpChatCompletion(ILLamaExecutor model, - ChatRequestSettings? defaultRequestSettings = default, + LLamaSharpPromptExecutionSettings? defaultRequestSettings = default, IHistoryTransform? historyTransform = null, ITextStreamTransform? outputTransform = null) { @@ -65,7 +65,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService public async Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { var settings = executionSettings != null - ? ChatRequestSettings.FromRequestSettings(executionSettings) + ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings) : defaultRequestSettings; var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); @@ -86,7 +86,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService public async IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var settings = executionSettings != null - ? ChatRequestSettings.FromRequestSettings(executionSettings) + ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings) : defaultRequestSettings; var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs b/LLama.SemanticKernel/ChatRequestSettings.cs similarity index 76% rename from LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs rename to LLama.SemanticKernel/ChatRequestSettings.cs index ac22e1fc..87dda39e 100644 --- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs +++ b/LLama.SemanticKernel/ChatRequestSettings.cs @@ -1,10 +1,22 @@ -using Microsoft.SemanticKernel; + +/* Unmerged change from project 'LLamaSharp.SemanticKernel (netstandard2.0)' +Before: +using Microsoft.SemanticKernel; +After: +using LLamaSharp; +using LLamaSharp.SemanticKernel; +using LLamaSharp.SemanticKernel; +using LLamaSharp.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel; +*/ +using LLamaSharp.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel; using System.Text.Json; using System.Text.Json.Serialization; -namespace LLamaSharp.SemanticKernel.ChatCompletion; +namespace LLamaSharp.SemanticKernel; -public class ChatRequestSettings : PromptExecutionSettings +public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings { /// /// Temperature controls the randomness of the completion. @@ -68,30 +80,30 @@ public class ChatRequestSettings : PromptExecutionSettings /// Template configuration /// Default max tokens /// An instance of OpenAIRequestSettings - public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null) + public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null) { if (requestSettings is null) { - return new ChatRequestSettings() + return new LLamaSharpPromptExecutionSettings() { MaxTokens = defaultMaxTokens }; } - if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings) + if (requestSettings is LLamaSharpPromptExecutionSettings requestSettingsChatRequestSettings) { return requestSettingsChatRequestSettings; } var json = JsonSerializer.Serialize(requestSettings); - var chatRequestSettings = JsonSerializer.Deserialize(json, s_options); + var chatRequestSettings = JsonSerializer.Deserialize(json, s_options); if (chatRequestSettings is not null) { return chatRequestSettings; } - throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings)); + throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings)); } private static readonly JsonSerializerOptions s_options = CreateOptions(); @@ -105,7 +117,7 @@ public class ChatRequestSettings : PromptExecutionSettings AllowTrailingCommas = true, PropertyNameCaseInsensitive = true, ReadCommentHandling = JsonCommentHandling.Skip, - Converters = { new ChatRequestSettingsConverter() } + Converters = { new LLamaSharpPromptExecutionSettingsConverter() } }; return options; diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs b/LLama.SemanticKernel/ChatRequestSettingsConverter.cs similarity index 88% rename from LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs rename to LLama.SemanticKernel/ChatRequestSettingsConverter.cs index e320ea3f..36ca9c6c 100644 --- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs +++ b/LLama.SemanticKernel/ChatRequestSettingsConverter.cs @@ -3,17 +3,17 @@ using System.Collections.Generic; using System.Text.Json; using System.Text.Json.Serialization; -namespace LLamaSharp.SemanticKernel.ChatCompletion; +namespace LLamaSharp.SemanticKernel; /// /// JSON converter for /// -public class ChatRequestSettingsConverter : JsonConverter +public class LLamaSharpPromptExecutionSettingsConverter : JsonConverter { /// - public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override LLamaSharpPromptExecutionSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - var requestSettings = new ChatRequestSettings(); + var requestSettings = new LLamaSharpPromptExecutionSettings(); while (reader.Read() && reader.TokenType != JsonTokenType.EndObject) { @@ -77,7 +77,7 @@ public class ChatRequestSettingsConverter : JsonConverter } /// - public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, LLamaSharpPromptExecutionSettings value, JsonSerializerOptions options) { writer.WriteStartObject(); diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs index 85f9064c..086999aa 100644 --- a/LLama.SemanticKernel/ExtensionMethods.cs +++ b/LLama.SemanticKernel/ExtensionMethods.cs @@ -1,5 +1,4 @@ -using LLamaSharp.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.ChatCompletion; namespace LLamaSharp.SemanticKernel; public static class ExtensionMethods @@ -23,11 +22,11 @@ public static class ExtensionMethods } /// - /// Convert ChatRequestSettings to LLamaSharp InferenceParams + /// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams /// /// /// - internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this ChatRequestSettings requestSettings) + internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings) { if (requestSettings is null) { diff --git a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs index 08ec33e1..31e07b2b 100644 --- a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs +++ b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs @@ -1,5 +1,4 @@ using LLama.Abstractions; -using LLamaSharp.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Services; using Microsoft.SemanticKernel.TextGeneration; @@ -24,7 +23,7 @@ public sealed class LLamaSharpTextCompletion : ITextGenerationService /// public async Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { - var settings = ChatRequestSettings.FromRequestSettings(executionSettings); + var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings); var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken); var sb = new StringBuilder(); await foreach (var token in result) @@ -37,7 +36,7 @@ public sealed class LLamaSharpTextCompletion : ITextGenerationService /// public async IAsyncEnumerable GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var settings = ChatRequestSettings.FromRequestSettings(executionSettings); + var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings); var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken); await foreach (var token in result) { diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs index 4190e852..4828a407 100644 --- a/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs +++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs @@ -1,4 +1,5 @@ -using LLamaSharp.SemanticKernel.ChatCompletion; +using LLamaSharp.SemanticKernel; +using LLamaSharp.SemanticKernel.ChatCompletion; using System.Text.Json; namespace LLama.Unittest.SemanticKernel @@ -10,11 +11,11 @@ namespace LLama.Unittest.SemanticKernel { // Arrange var options = new JsonSerializerOptions(); - options.Converters.Add(new ChatRequestSettingsConverter()); + options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter()); var json = "{}"; // Act - var requestSettings = JsonSerializer.Deserialize(json, options); + var requestSettings = JsonSerializer.Deserialize(json, options); // Assert Assert.NotNull(requestSettings); @@ -36,7 +37,7 @@ namespace LLama.Unittest.SemanticKernel // Arrange var options = new JsonSerializerOptions(); options.AllowTrailingCommas = true; - options.Converters.Add(new ChatRequestSettingsConverter()); + options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter()); var json = @"{ ""frequency_penalty"": 0.5, ""max_tokens"": 250, @@ -49,7 +50,7 @@ namespace LLama.Unittest.SemanticKernel }"; // Act - var requestSettings = JsonSerializer.Deserialize(json, options); + var requestSettings = JsonSerializer.Deserialize(json, options); // Assert Assert.NotNull(requestSettings); @@ -73,7 +74,7 @@ namespace LLama.Unittest.SemanticKernel // Arrange var options = new JsonSerializerOptions(); options.AllowTrailingCommas = true; - options.Converters.Add(new ChatRequestSettingsConverter()); + options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter()); var json = @"{ ""FrequencyPenalty"": 0.5, ""MaxTokens"": 250, @@ -86,7 +87,7 @@ namespace LLama.Unittest.SemanticKernel }"; // Act - var requestSettings = JsonSerializer.Deserialize(json, options); + var requestSettings = JsonSerializer.Deserialize(json, options); // Assert Assert.NotNull(requestSettings); diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs index ef5d9670..d75a8d4b 100644 --- a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs +++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs @@ -1,4 +1,4 @@ -using LLamaSharp.SemanticKernel.ChatCompletion; +using LLamaSharp.SemanticKernel; using Microsoft.SemanticKernel; namespace LLama.Unittest.SemanticKernel @@ -10,7 +10,7 @@ namespace LLama.Unittest.SemanticKernel { // Arrange // Act - var requestSettings = ChatRequestSettings.FromRequestSettings(null, null); + var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, null); // Assert Assert.NotNull(requestSettings); @@ -31,7 +31,7 @@ namespace LLama.Unittest.SemanticKernel { // Arrange // Act - var requestSettings = ChatRequestSettings.FromRequestSettings(null, 200); + var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, 200); // Assert Assert.NotNull(requestSettings); @@ -51,7 +51,7 @@ namespace LLama.Unittest.SemanticKernel public void ChatRequestSettings_FromExistingRequestSettings() { // Arrange - var originalRequestSettings = new ChatRequestSettings() + var originalRequestSettings = new LLamaSharpPromptExecutionSettings() { FrequencyPenalty = 0.5, MaxTokens = 100, @@ -64,7 +64,7 @@ namespace LLama.Unittest.SemanticKernel }; // Act - var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings); + var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings); // Assert Assert.NotNull(requestSettings); @@ -81,7 +81,7 @@ namespace LLama.Unittest.SemanticKernel }; // Act - var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings); + var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings); // Assert Assert.NotNull(requestSettings); @@ -109,7 +109,7 @@ namespace LLama.Unittest.SemanticKernel }; // Act - var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings); + var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings); // Assert Assert.NotNull(requestSettings); @@ -148,7 +148,7 @@ namespace LLama.Unittest.SemanticKernel }; // Act - var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings); + var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings); // Assert Assert.NotNull(requestSettings); diff --git a/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs b/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs index dfcef182..574611fc 100644 --- a/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs +++ b/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs @@ -37,7 +37,7 @@ namespace LLamaSharp.SemanticKernel.Tests public void ToLLamaSharpInferenceParams_StateUnderTest_ExpectedBehavior() { // Arrange - var requestSettings = new ChatRequestSettings(); + var requestSettings = new LLamaSharpPromptExecutionSettings(); // Act var result = ExtensionMethods.ToLLamaSharpInferenceParams(