@@ -1,4 +1,6 @@
 using Microsoft.SemanticKernel.AI;
+using System.Text.Json;
+using System.Text.Json.Serialization;

 namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -8,12 +10,14 @@ public class ChatRequestSettings : AIRequestSettings
     /// Temperature controls the randomness of the completion.
     /// The higher the temperature, the more random the completion.
     /// </summary>
+    [JsonPropertyName("temperature")]
     public double Temperature { get; set; } = 0;

     /// <summary>
     /// TopP controls the diversity of the completion.
     /// The higher the TopP, the more diverse the completion.
     /// </summary>
+    [JsonPropertyName("top_p")]
     public double TopP { get; set; } = 0;

     /// <summary>
@@ -21,6 +25,7 @@ public class ChatRequestSettings : AIRequestSettings
     /// based on whether they appear in the text so far, increasing the
     /// model's likelihood to talk about new topics.
     /// </summary>
+    [JsonPropertyName("presence_penalty")]
     public double PresencePenalty { get; set; } = 0;

     /// <summary>
@@ -28,11 +33,13 @@ public class ChatRequestSettings : AIRequestSettings
     /// based on their existing frequency in the text so far, decreasing
     /// the model's likelihood to repeat the same line verbatim.
     /// </summary>
+    [JsonPropertyName("frequency_penalty")]
     public double FrequencyPenalty { get; set; } = 0;

     /// <summary>
     /// Sequences where the completion will stop generating further tokens.
     /// </summary>
+    [JsonPropertyName("stop_sequences")]
     public IList<string> StopSequences { get; set; } = Array.Empty<string>();

     /// <summary>
@@ -40,15 +47,67 @@ public class ChatRequestSettings : AIRequestSettings
     /// Note: Because this parameter generates many completions, it can quickly consume your token quota.
     /// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
     /// </summary>
+    [JsonPropertyName("results_per_prompt")]
     public int ResultsPerPrompt { get; set; } = 1;

     /// <summary>
     /// The maximum number of tokens to generate in the completion.
     /// </summary>
+    [JsonPropertyName("max_tokens")]
     public int? MaxTokens { get; set; }

     /// <summary>
     /// Modify the likelihood of specified tokens appearing in the completion.
     /// </summary>
+    [JsonPropertyName("token_selection_biases")]
     public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();
+
+    /// <summary>
+    /// Create a new settings object with the values from another settings object.
+    /// </summary>
+    /// <param name="requestSettings">Template configuration</param>
+    /// <param name="defaultMaxTokens">Default max tokens</param>
+    /// <returns>An instance of <see cref="ChatRequestSettings"/></returns>
+    public static ChatRequestSettings FromRequestSettings(AIRequestSettings? requestSettings, int? defaultMaxTokens = null)
+    {
+        if (requestSettings is null)
+        {
+            return new ChatRequestSettings()
+            {
+                MaxTokens = defaultMaxTokens
+            };
+        }
+
+        // Already the right type: return it as-is.
+        if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
+        {
+            return requestSettingsChatRequestSettings;
+        }
+
+        // Otherwise, round-trip through JSON so that any AIRequestSettings-derived
+        // type with compatible properties can be converted.
+        var json = JsonSerializer.Serialize(requestSettings);
+        var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
+
+        if (chatRequestSettings is not null)
+        {
+            return chatRequestSettings;
+        }
+
+        throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
+    }
+
+    private static readonly JsonSerializerOptions s_options = CreateOptions();
+
+    private static JsonSerializerOptions CreateOptions()
+    {
+        JsonSerializerOptions options = new()
+        {
+            WriteIndented = true,
+            MaxDepth = 20,
+            AllowTrailingCommas = true,
+            PropertyNameCaseInsensitive = true,
+            ReadCommentHandling = JsonCommentHandling.Skip,
+            Converters = { new ChatRequestSettingsConverter() }
+        };
+
+        return options;
+    }
 }
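The `FromRequestSettings` helper added above replaces the hard casts used elsewhere in this PR: `null` falls back to defaults, an existing `ChatRequestSettings` passes through untouched, and any other `AIRequestSettings` subtype is converted via a JSON round-trip. A minimal sketch of the three paths, using a hypothetical `FakeRequestSettings` subclass that is not part of this change:

```csharp
// Hypothetical settings type, for illustration only.
public class FakeRequestSettings : AIRequestSettings
{
    public double Temperature { get; set; } = 0.7;
    public int MaxTokens { get; set; } = 256;
}

// 1. Null input falls back to defaults, honouring defaultMaxTokens.
var fromNull = ChatRequestSettings.FromRequestSettings(null, defaultMaxTokens: 128);
// fromNull.MaxTokens == 128

// 2. An existing ChatRequestSettings instance is returned unchanged.
var passThrough = ChatRequestSettings.FromRequestSettings(new ChatRequestSettings { TopP = 0.9 });

// 3. Any other subtype is serialized and re-read as ChatRequestSettings,
//    so matching properties carry over.
var converted = ChatRequestSettings.FromRequestSettings(new FakeRequestSettings());
// converted.Temperature == 0.7, converted.MaxTokens == 256
```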
@@ -0,0 +1,105 @@
+using System;
+using System.Collections.Generic;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace LLamaSharp.SemanticKernel.ChatCompletion;
+
+/// <summary>
+/// JSON converter for <see cref="ChatRequestSettings"/>
+/// </summary>
+public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
+{
+    /// <inheritdoc/>
+    public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+    {
+        var requestSettings = new ChatRequestSettings();
+
+        while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
+        {
+            if (reader.TokenType == JsonTokenType.PropertyName)
+            {
+                string? propertyName = reader.GetString();
+
+                if (propertyName is not null)
+                {
+                    // Normalise the property name to uppercase so that both
+                    // PascalCase and snake_case spellings match below.
+                    propertyName = propertyName.ToUpperInvariant();
+                }
+
+                reader.Read();
+
+                switch (propertyName)
+                {
+                    case "TEMPERATURE":
+                        requestSettings.Temperature = reader.GetDouble();
+                        break;
+                    case "TOPP":
+                    case "TOP_P":
+                        requestSettings.TopP = reader.GetDouble();
+                        break;
+                    case "FREQUENCYPENALTY":
+                    case "FREQUENCY_PENALTY":
+                        requestSettings.FrequencyPenalty = reader.GetDouble();
+                        break;
+                    case "PRESENCEPENALTY":
+                    case "PRESENCE_PENALTY":
+                        requestSettings.PresencePenalty = reader.GetDouble();
+                        break;
+                    case "MAXTOKENS":
+                    case "MAX_TOKENS":
+                        requestSettings.MaxTokens = reader.GetInt32();
+                        break;
+                    case "STOPSEQUENCES":
+                    case "STOP_SEQUENCES":
+                        requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
+                        break;
+                    case "RESULTSPERPROMPT":
+                    case "RESULTS_PER_PROMPT":
+                        requestSettings.ResultsPerPrompt = reader.GetInt32();
+                        break;
+                    case "TOKENSELECTIONBIASES":
+                    case "TOKEN_SELECTION_BIASES":
+                        requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
+                        break;
+                    case "SERVICEID":
+                    case "SERVICE_ID":
+                        requestSettings.ServiceId = reader.GetString();
+                        break;
+                    default:
+                        // Unknown property: skip its value entirely.
+                        reader.Skip();
+                        break;
+                }
+            }
+        }
+
+        return requestSettings;
+    }
+
+    /// <inheritdoc/>
+    public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
+    {
+        writer.WriteStartObject();
+        writer.WriteNumber("temperature", value.Temperature);
+        writer.WriteNumber("top_p", value.TopP);
+        writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
+        writer.WriteNumber("presence_penalty", value.PresencePenalty);
+
+        if (value.MaxTokens is null)
+        {
+            writer.WriteNull("max_tokens");
+        }
+        else
+        {
+            writer.WriteNumber("max_tokens", (decimal)value.MaxTokens);
+        }
+
+        writer.WritePropertyName("stop_sequences");
+        JsonSerializer.Serialize(writer, value.StopSequences, options);
+        writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);
+        writer.WritePropertyName("token_selection_biases");
+        JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);
+        writer.WriteString("service_id", value.ServiceId);
+        writer.WriteEndObject();
+    }
+}
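Because `Read` uppercases each property name and matches both the joined and snake_case spellings, the converter accepts JSON produced by its own `Write` method (snake_case) as well as default System.Text.Json output for other settings types (PascalCase). A quick sketch of assumed usage:

```csharp
using System.Text.Json;

var options = new JsonSerializerOptions
{
    Converters = { new ChatRequestSettingsConverter() }
};

// Wire format, as emitted by Write().
var fromWire = JsonSerializer.Deserialize<ChatRequestSettings>(
    "{\"temperature\":0.5,\"max_tokens\":64}", options);

// PascalCase, as emitted by default serialization of another settings type.
var fromDotnet = JsonSerializer.Deserialize<ChatRequestSettings>(
    "{\"Temperature\":0.5,\"MaxTokens\":64}", options);

// Both yield Temperature == 0.5 and MaxTokens == 64.
```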
@@ -61,7 +61,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
     public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
     {
         var settings = requestSettings != null
-            ? (ChatRequestSettings)requestSettings
+            ? ChatRequestSettings.FromRequestSettings(requestSettings)
             : defaultRequestSettings;

         // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
@@ -76,7 +76,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
 #pragma warning restore CS1998
     {
         var settings = requestSettings != null
-            ? (ChatRequestSettings)requestSettings
+            ? ChatRequestSettings.FromRequestSettings(requestSettings)
             : defaultRequestSettings;

         // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
@@ -21,7 +21,7 @@ public sealed class LLamaSharpTextCompletion : ITextCompletion
     public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, AIRequestSettings? requestSettings, CancellationToken cancellationToken = default)
     {
-        var settings = (ChatRequestSettings?)requestSettings;
+        var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
         var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
         return await Task.FromResult(new List<ITextResult> { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false);
     }
@@ -30,7 +30,7 @@ public sealed class LLamaSharpTextCompletion : ITextCompletion
     public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync(string text, AIRequestSettings? requestSettings, [EnumeratorCancellation] CancellationToken cancellationToken = default)
 #pragma warning restore CS1998
     {
-        var settings = (ChatRequestSettings?)requestSettings;
+        var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
         var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
         yield return new LLamaTextResult(result);
     }
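The change in these call sites is behavioural, not cosmetic: the old hard cast threw an `InvalidCastException` whenever a caller passed any `AIRequestSettings` that was not already a `ChatRequestSettings`. A caller-side sketch, assuming the stock `AIRequestSettings` base type is instantiable here:

```csharp
AIRequestSettings generic = new AIRequestSettings { ServiceId = "local-llama" };

// Before this PR (inside the completion classes):
// var settings = (ChatRequestSettings)generic;   // InvalidCastException

// After this PR: any AIRequestSettings is accepted and converted.
var settings = ChatRequestSettings.FromRequestSettings(generic);
// settings.ServiceId == "local-llama"
```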