diff --git a/LLama.Examples/Examples/SemanticKernelPrompt.cs b/LLama.Examples/Examples/SemanticKernelPrompt.cs
index 63e848cb..0d62e0b3 100644
--- a/LLama.Examples/Examples/SemanticKernelPrompt.cs
+++ b/LLama.Examples/Examples/SemanticKernelPrompt.cs
@@ -1,9 +1,9 @@
 using LLama.Common;
-using LLamaSharp.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel;
 using LLamaSharp.SemanticKernel.TextCompletion;
 using Microsoft.SemanticKernel.TextGeneration;
 using Microsoft.Extensions.DependencyInjection;
+using LLamaSharp.SemanticKernel;
 
 namespace LLama.Examples.Examples
 {
@@ -31,7 +31,7 @@ namespace LLama.Examples.Examples
 
 One line TLDR with the fewest words.";
 
-            ChatRequestSettings settings = new() { MaxTokens = 100 };
+            LLamaSharpPromptExecutionSettings settings = new() { MaxTokens = 100 };
             var summarize = kernel.CreateFunctionFromPrompt(prompt, settings);
 
             string text1 = @"
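As the example change shows, migrating a call site is a one-line substitution: construct LLamaSharpPromptExecutionSettings wherever ChatRequestSettings was constructed before. A minimal sketch of the updated pattern, assuming a kernel already wired to a LLamaSharp service as in this example file (prompt text and input are placeholders):

    using LLamaSharp.SemanticKernel;
    using Microsoft.SemanticKernel;

    // 'kernel' is assumed to be built elsewhere, e.g. via Kernel.CreateBuilder()
    // with a LLamaSharp text-generation service registered.
    var settings = new LLamaSharpPromptExecutionSettings { MaxTokens = 100, Temperature = 0 };
    var summarize = kernel.CreateFunctionFromPrompt(
        "{{$input}}\n\nOne line TLDR with the fewest words.", settings);
    var result = await kernel.InvokeAsync(summarize, new() { ["input"] = "..." });
    Console.WriteLine(result.GetValue<string>());

Because LLamaSharpPromptExecutionSettings derives from PromptExecutionSettings, it can be passed anywhere Semantic Kernel accepts execution settings.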
diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
index ac22e1fc..683f8c45 100644
--- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
@@ -4,6 +4,7 @@ using System.Text.Json.Serialization;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
 
+[Obsolete("Use LLamaSharpPromptExecutionSettings instead")]
 public class ChatRequestSettings : PromptExecutionSettings
 {
     /// <summary>
diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
index e320ea3f..15bc45cd 100644
--- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
@@ -8,6 +8,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
 /// <summary>
 /// JSON converter for <see cref="ChatRequestSettings"/>
 /// </summary>
+[Obsolete("Use LLamaSharpPromptExecutionSettingsConverter instead")]
 public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
 {
     /// <inheritdoc/>
diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index a7ac6e8e..01a061db 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -18,7 +18,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
 public sealed class LLamaSharpChatCompletion : IChatCompletionService
 {
     private readonly ILLamaExecutor _model;
-    private ChatRequestSettings defaultRequestSettings;
+    private LLamaSharpPromptExecutionSettings defaultRequestSettings;
     private readonly IHistoryTransform historyTransform;
     private readonly ITextStreamTransform outputTransform;
 
@@ -27,9 +27,9 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
 
     public IReadOnlyDictionary<string, object?> Attributes => this._attributes;
 
-    static ChatRequestSettings GetDefaultSettings()
+    static LLamaSharpPromptExecutionSettings GetDefaultSettings()
     {
-        return new ChatRequestSettings
+        return new LLamaSharpPromptExecutionSettings
         {
             MaxTokens = 256,
             Temperature = 0,
@@ -39,7 +39,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
     }
 
     public LLamaSharpChatCompletion(ILLamaExecutor model,
-        ChatRequestSettings? defaultRequestSettings = default,
+        LLamaSharpPromptExecutionSettings? defaultRequestSettings = default,
         IHistoryTransform? historyTransform = null,
         ITextStreamTransform? outputTransform = null)
     {
@@ -68,7 +68,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
     public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
     {
         var settings = executionSettings != null
-            ? ChatRequestSettings.FromRequestSettings(executionSettings)
+            ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
             : defaultRequestSettings;
 
         string prompt = this._getFormattedPrompt(chatHistory);
@@ -89,7 +89,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
     public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         var settings = executionSettings != null
-            ? ChatRequestSettings.FromRequestSettings(executionSettings)
+            ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
             : defaultRequestSettings;
 
         string prompt = this._getFormattedPrompt(chatHistory);
diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs
index 85f9064c..086999aa 100644
--- a/LLama.SemanticKernel/ExtensionMethods.cs
+++ b/LLama.SemanticKernel/ExtensionMethods.cs
@@ -1,5 +1,4 @@
-using LLamaSharp.SemanticKernel.ChatCompletion;
-using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.ChatCompletion;
 
 namespace LLamaSharp.SemanticKernel;
 public static class ExtensionMethods
@@ -23,11 +22,11 @@ public static class ExtensionMethods
     }
 
     /// <summary>
-    /// Convert ChatRequestSettings to LLamaSharp InferenceParams
+    /// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams
     /// </summary>
     /// <param name="requestSettings"></param>
    /// <returns></returns>
-    internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this ChatRequestSettings requestSettings)
+    internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings)
    {
        if (requestSettings is null)
        {
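ToLLamaSharpInferenceParams, the method retyped in the hunk above, is the bridge between Semantic Kernel settings and LLamaSharp's native inference parameters. Roughly, it copies the sampling fields across and turns stop sequences into anti-prompts; the sketch below is an approximation for orientation only, not the verbatim method body (ToInferenceParamsSketch is a hypothetical name, and the -1 "no limit" convention is an assumption):

    using LLama.Common;
    using LLamaSharp.SemanticKernel;

    static InferenceParams ToInferenceParamsSketch(LLamaSharpPromptExecutionSettings s) => new()
    {
        Temperature = (float)s.Temperature,
        TopP = (float)s.TopP,
        PresencePenalty = (float)s.PresencePenalty,
        FrequencyPenalty = (float)s.FrequencyPenalty,
        AntiPrompts = s.StopSequences,    // stop sequences act as anti-prompts
        MaxTokens = s.MaxTokens ?? -1     // assumption: -1 is treated as "no limit"
    };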
diff --git a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
new file mode 100644
index 00000000..5e8a6669
--- /dev/null
+++ b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
@@ -0,0 +1,131 @@
+
+/* Unmerged change from project 'LLamaSharp.SemanticKernel (netstandard2.0)'
+Before:
+using Microsoft.SemanticKernel;
+After:
+using LLamaSharp;
+using LLamaSharp.SemanticKernel;
+using LLamaSharp.SemanticKernel;
+using LLamaSharp.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel;
+*/
+using LLamaSharp.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace LLamaSharp.SemanticKernel;
+
+public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
+{
+    /// <summary>
+    /// Temperature controls the randomness of the completion.
+    /// The higher the temperature, the more random the completion.
+    /// </summary>
+    [JsonPropertyName("temperature")]
+    public double Temperature { get; set; } = 0;
+
+    /// <summary>
+    /// TopP controls the diversity of the completion.
+    /// The higher the TopP, the more diverse the completion.
+    /// </summary>
+    [JsonPropertyName("top_p")]
+    public double TopP { get; set; } = 0;
+
+    /// <summary>
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens
+    /// based on whether they appear in the text so far, increasing the
+    /// model's likelihood to talk about new topics.
+    /// </summary>
+    [JsonPropertyName("presence_penalty")]
+    public double PresencePenalty { get; set; } = 0;
+
+    /// <summary>
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens
+    /// based on their existing frequency in the text so far, decreasing
+    /// the model's likelihood to repeat the same line verbatim.
+    /// </summary>
+    [JsonPropertyName("frequency_penalty")]
+    public double FrequencyPenalty { get; set; } = 0;
+
+    /// <summary>
+    /// Sequences where the completion will stop generating further tokens.
+    /// </summary>
+    [JsonPropertyName("stop_sequences")]
+    public IList<string> StopSequences { get; set; } = Array.Empty<string>();
+
+    /// <summary>
+    /// How many completions to generate for each prompt. Default is 1.
+    /// Note: Because this parameter generates many completions, it can quickly consume your token quota.
+    /// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
+    /// </summary>
+    [JsonPropertyName("results_per_prompt")]
+    public int ResultsPerPrompt { get; set; } = 1;
+
+    /// <summary>
+    /// The maximum number of tokens to generate in the completion.
+    /// </summary>
+    [JsonPropertyName("max_tokens")]
+    public int? MaxTokens { get; set; }
+
+    /// <summary>
+    /// Modify the likelihood of specified tokens appearing in the completion.
+    /// </summary>
+    [JsonPropertyName("token_selection_biases")]
+    public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();
+
+    /// <summary>
+    /// Indicates the format of the response which can be used downstream to post-process the messages. Handlebars: handlebars_object. JSON: json_object, etc.
+    /// </summary>
+    [JsonPropertyName("response_format")]
+    public string ResponseFormat { get; set; } = string.Empty;
+
+    /// <summary>
+    /// Create a new settings object with the values from another settings object.
+    /// </summary>
+    /// <param name="requestSettings">Template configuration</param>
+    /// <param name="defaultMaxTokens">Default max tokens</param>
+    /// <returns>An instance of OpenAIRequestSettings</returns>
+    public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
+    {
+        if (requestSettings is null)
+        {
+            return new LLamaSharpPromptExecutionSettings()
+            {
+                MaxTokens = defaultMaxTokens
+            };
+        }
+
+        if (requestSettings is LLamaSharpPromptExecutionSettings requestSettingsChatRequestSettings)
+        {
+            return requestSettingsChatRequestSettings;
+        }
+
+        var json = JsonSerializer.Serialize(requestSettings);
+        var chatRequestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, s_options);
+
+        if (chatRequestSettings is not null)
+        {
+            return chatRequestSettings;
+        }
+
+        throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
+    }
+
+    private static readonly JsonSerializerOptions s_options = CreateOptions();
+
+    private static JsonSerializerOptions CreateOptions()
+    {
+        JsonSerializerOptions options = new()
+        {
+            WriteIndented = true,
+            MaxDepth = 20,
+            AllowTrailingCommas = true,
+            PropertyNameCaseInsensitive = true,
+            ReadCommentHandling = JsonCommentHandling.Skip,
+            Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
+        };
+
+        return options;
+    }
+}
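FromRequestSettings gives every service class a single normalisation point with three paths: null falls back to defaults, an instance of the target type passes through untouched, and any other PromptExecutionSettings subtype is round-tripped through JSON using the converter registered in s_options. A short sketch of the observable behaviour (values are illustrative):

    using LLamaSharp.SemanticKernel;
    using Microsoft.SemanticKernel;

    // 1. null -> defaults, with the caller-supplied fallback for MaxTokens
    var fromNull = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, defaultMaxTokens: 256);

    // 2. the target type is returned as-is; no copy is made
    var native = new LLamaSharpPromptExecutionSettings { Temperature = 0.7, MaxTokens = 100 };
    var same = LLamaSharpPromptExecutionSettings.FromRequestSettings(native);
    // ReferenceEquals(native, same) == true

    // 3. any other settings type is converted via JSON serialization;
    // recognised properties carry over, everything else is skipped
    var generic = new PromptExecutionSettings { ModelId = "local-llama" };
    var converted = LLamaSharpPromptExecutionSettings.FromRequestSettings(generic);
    // converted.ModelId == "local-llama"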
diff --git a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs
new file mode 100644
index 00000000..36ca9c6c
--- /dev/null
+++ b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs
@@ -0,0 +1,104 @@
+using System;
+using System.Collections.Generic;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace LLamaSharp.SemanticKernel;
+
+/// <summary>
+/// JSON converter for <see cref="LLamaSharpPromptExecutionSettings"/>
+/// </summary>
+public class LLamaSharpPromptExecutionSettingsConverter : JsonConverter<LLamaSharpPromptExecutionSettings>
+{
+    /// <inheritdoc/>
+    public override LLamaSharpPromptExecutionSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+    {
+        var requestSettings = new LLamaSharpPromptExecutionSettings();
+
+        while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
+        {
+            if (reader.TokenType == JsonTokenType.PropertyName)
+            {
+                string? propertyName = reader.GetString();
+
+                if (propertyName is not null)
+                {
+                    // normalise property name to uppercase
+                    propertyName = propertyName.ToUpperInvariant();
+                }
+
+                reader.Read();
+
+                switch (propertyName)
+                {
+                    case "MODELID":
+                    case "MODEL_ID":
+                        requestSettings.ModelId = reader.GetString();
+                        break;
+                    case "TEMPERATURE":
+                        requestSettings.Temperature = reader.GetDouble();
+                        break;
+                    case "TOPP":
+                    case "TOP_P":
+                        requestSettings.TopP = reader.GetDouble();
+                        break;
+                    case "FREQUENCYPENALTY":
+                    case "FREQUENCY_PENALTY":
+                        requestSettings.FrequencyPenalty = reader.GetDouble();
+                        break;
+                    case "PRESENCEPENALTY":
+                    case "PRESENCE_PENALTY":
+                        requestSettings.PresencePenalty = reader.GetDouble();
+                        break;
+                    case "MAXTOKENS":
+                    case "MAX_TOKENS":
+                        requestSettings.MaxTokens = reader.GetInt32();
+                        break;
+                    case "STOPSEQUENCES":
+                    case "STOP_SEQUENCES":
+                        requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
+                        break;
+                    case "RESULTSPERPROMPT":
+                    case "RESULTS_PER_PROMPT":
+                        requestSettings.ResultsPerPrompt = reader.GetInt32();
+                        break;
+                    case "TOKENSELECTIONBIASES":
+                    case "TOKEN_SELECTION_BIASES":
+                        requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
+                        break;
+                    default:
+                        reader.Skip();
+                        break;
+                }
+            }
+        }
+
+        return requestSettings;
+    }
+
+    /// <inheritdoc/>
+    public override void Write(Utf8JsonWriter writer, LLamaSharpPromptExecutionSettings value, JsonSerializerOptions options)
+    {
+        writer.WriteStartObject();
+
+        writer.WriteNumber("temperature", value.Temperature);
+        writer.WriteNumber("top_p", value.TopP);
+        writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
+        writer.WriteNumber("presence_penalty", value.PresencePenalty);
+        if (value.MaxTokens is null)
+        {
+            writer.WriteNull("max_tokens");
+        }
+        else
+        {
+            writer.WriteNumber("max_tokens", (decimal)value.MaxTokens);
+        }
+        writer.WritePropertyName("stop_sequences");
+        JsonSerializer.Serialize(writer, value.StopSequences, options);
+        writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);
+        writer.WritePropertyName("token_selection_biases");
+        JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);
+
+        writer.WriteEndObject();
+    }
+}
\ No newline at end of file
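Note the converter's reading side upper-cases each property name before the switch, so both snake_case and PascalCase spellings deserialize onto the same properties, while the writing side always emits snake_case. A quick round-trip sketch mirroring the unit tests further down in this diff:

    using System.Text.Json;
    using LLamaSharp.SemanticKernel;

    var options = new JsonSerializerOptions
    {
        Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
    };

    // snake_case and PascalCase map onto the same properties
    var a = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(
        @"{ ""max_tokens"": 250, ""temperature"": 0.6 }", options);
    var b = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(
        @"{ ""MaxTokens"": 250, ""Temperature"": 0.6 }", options);
    // a.MaxTokens == b.MaxTokens == 250

    var json = JsonSerializer.Serialize(a, options); // emits snake_case names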
diff --git a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
index 08ec33e1..31e07b2b 100644
--- a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
+++ b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
@@ -1,5 +1,4 @@
 using LLama.Abstractions;
-using LLamaSharp.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.Services;
 using Microsoft.SemanticKernel.TextGeneration;
@@ -24,7 +23,7 @@ public sealed class LLamaSharpTextCompletion : ITextGenerationService
    /// <inheritdoc/>
    public async Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
-        var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
+        var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
        var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
        var sb = new StringBuilder();
        await foreach (var token in result)
@@ -37,7 +36,7 @@ public sealed class LLamaSharpTextCompletion : ITextGenerationService
    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
-        var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
+        var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
        var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
        await foreach (var token in result)
        {
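For completeness, here is how the retyped text-completion service is typically registered, following the pattern of the example file at the top of this diff; the model path and service key are placeholders, not part of this change:

    using LLama;
    using LLama.Common;
    using LLamaSharp.SemanticKernel.TextCompletion;
    using Microsoft.Extensions.DependencyInjection;
    using Microsoft.SemanticKernel;
    using Microsoft.SemanticKernel.TextGeneration;

    var parameters = new ModelParams("path/to/model.gguf"); // placeholder path
    using var model = LLamaWeights.LoadFromFile(parameters);
    var executor = new StatelessExecutor(model, parameters);

    var builder = Kernel.CreateBuilder();
    builder.Services.AddKeyedSingleton<ITextGenerationService>(
        "local-llama", new LLamaSharpTextCompletion(executor));
    var kernel = builder.Build();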
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
index 4190e852..4828a407 100644
--- a/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
@@ -1,4 +1,5 @@
-using LLamaSharp.SemanticKernel.ChatCompletion;
+using LLamaSharp.SemanticKernel;
+using LLamaSharp.SemanticKernel.ChatCompletion;
 using System.Text.Json;
 
 namespace LLama.Unittest.SemanticKernel
@@ -10,11 +11,11 @@ namespace LLama.Unittest.SemanticKernel
        {
            // Arrange
            var options = new JsonSerializerOptions();
-            options.Converters.Add(new ChatRequestSettingsConverter());
+            options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
            var json = "{}";
 
            // Act
-            var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+            var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);
 
            // Assert
            Assert.NotNull(requestSettings);
@@ -36,7 +37,7 @@ namespace LLama.Unittest.SemanticKernel
            // Arrange
            var options = new JsonSerializerOptions();
            options.AllowTrailingCommas = true;
-            options.Converters.Add(new ChatRequestSettingsConverter());
+            options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
            var json = @"{
                ""frequency_penalty"": 0.5,
                ""max_tokens"": 250,
@@ -49,7 +50,7 @@ namespace LLama.Unittest.SemanticKernel
            }";
 
            // Act
-            var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+            var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);
 
            // Assert
            Assert.NotNull(requestSettings);
@@ -73,7 +74,7 @@ namespace LLama.Unittest.SemanticKernel
            // Arrange
            var options = new JsonSerializerOptions();
            options.AllowTrailingCommas = true;
-            options.Converters.Add(new ChatRequestSettingsConverter());
+            options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
            var json = @"{
                ""FrequencyPenalty"": 0.5,
                ""MaxTokens"": 250,
@@ -86,7 +87,7 @@ namespace LLama.Unittest.SemanticKernel
            }";
 
            // Act
-            var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+            var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);
 
            // Assert
            Assert.NotNull(requestSettings);
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
index ef5d9670..d75a8d4b 100644
--- a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
@@ -1,4 +1,4 @@
-using LLamaSharp.SemanticKernel.ChatCompletion;
+using LLamaSharp.SemanticKernel;
 using Microsoft.SemanticKernel;
 
 namespace LLama.Unittest.SemanticKernel
@@ -10,7 +10,7 @@ namespace LLama.Unittest.SemanticKernel
        {
            // Arrange
            // Act
-            var requestSettings = ChatRequestSettings.FromRequestSettings(null, null);
+            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, null);
 
            // Assert
            Assert.NotNull(requestSettings);
@@ -31,7 +31,7 @@ namespace LLama.Unittest.SemanticKernel
        {
            // Arrange
            // Act
-            var requestSettings = ChatRequestSettings.FromRequestSettings(null, 200);
+            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, 200);
 
            // Assert
            Assert.NotNull(requestSettings);
@@ -51,7 +51,7 @@ namespace LLama.Unittest.SemanticKernel
        public void ChatRequestSettings_FromExistingRequestSettings()
        {
            // Arrange
-            var originalRequestSettings = new ChatRequestSettings()
+            var originalRequestSettings = new LLamaSharpPromptExecutionSettings()
            {
                FrequencyPenalty = 0.5,
                MaxTokens = 100,
@@ -64,7 +64,7 @@ namespace LLama.Unittest.SemanticKernel
            };
 
            // Act
-            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);
 
            // Assert
            Assert.NotNull(requestSettings);
@@ -81,7 +81,7 @@ namespace LLama.Unittest.SemanticKernel
            };
 
            // Act
-            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);
 
            // Assert
            Assert.NotNull(requestSettings);
@@ -109,7 +109,7 @@ namespace LLama.Unittest.SemanticKernel
            };
 
            // Act
-            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);
 
            // Assert
            Assert.NotNull(requestSettings);
@@ -148,7 +148,7 @@ namespace LLama.Unittest.SemanticKernel
            };
 
            // Act
-            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);
 
            // Assert
            Assert.NotNull(requestSettings);
diff --git a/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs b/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs
index dfcef182..574611fc 100644
--- a/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs
+++ b/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs
@@ -37,7 +37,7 @@ namespace LLamaSharp.SemanticKernel.Tests
        public void ToLLamaSharpInferenceParams_StateUnderTest_ExpectedBehavior()
        {
            // Arrange
-            var requestSettings = new ChatRequestSettings();
+            var requestSettings = new LLamaSharpPromptExecutionSettings();
 
            // Act
            var result = ExtensionMethods.ToLLamaSharpInferenceParams(