| @@ -1,4 +1,6 @@ | |||||
using Microsoft.SemanticKernel.AI;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

/// <summary>
/// Request settings for an LLamaSharp chat/text completion call,
/// exposing the familiar OpenAI-style sampling knobs.
/// </summary>
public class ChatRequestSettings : AIRequestSettings
{
    /// <summary>
    /// Temperature controls the randomness of the completion.
    /// The higher the temperature, the more random the completion.
    /// </summary>
    [JsonPropertyName("temperature")]
    public double Temperature { get; set; } = 0;

    /// <summary>
    /// TopP controls the diversity of the completion.
    /// The higher the TopP, the more diverse the completion.
    /// </summary>
    [JsonPropertyName("top_p")]
    public double TopP { get; set; } = 0;

    /// <summary>
    /// Number between -2.0 and 2.0. Positive values penalize new tokens
    /// based on whether they appear in the text so far, increasing the
    /// model's likelihood to talk about new topics.
    /// </summary>
    [JsonPropertyName("presence_penalty")]
    public double PresencePenalty { get; set; } = 0;

    /// <summary>
    /// Number between -2.0 and 2.0. Positive values penalize new tokens
    /// based on their existing frequency in the text so far, decreasing
    /// the model's likelihood to repeat the same line verbatim.
    /// </summary>
    [JsonPropertyName("frequency_penalty")]
    public double FrequencyPenalty { get; set; } = 0;

    /// <summary>
    /// Sequences where the completion will stop generating further tokens.
    /// </summary>
    [JsonPropertyName("stop_sequences")]
    public IList<string> StopSequences { get; set; } = Array.Empty<string>();

    /// <summary>
    /// How many completions to generate for each prompt.
    /// Note: Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
    /// </summary>
    [JsonPropertyName("results_per_prompt")]
    public int ResultsPerPrompt { get; set; } = 1;

    /// <summary>
    /// The maximum number of tokens to generate in the completion.
    /// </summary>
    [JsonPropertyName("max_tokens")]
    public int? MaxTokens { get; set; }

    /// <summary>
    /// Modify the likelihood of specified tokens appearing in the completion.
    /// </summary>
    [JsonPropertyName("token_selection_biases")]
    public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();

    /// <summary>
    /// Create a new settings object with the values from another settings object.
    /// </summary>
    /// <param name="requestSettings">Template configuration, or null to use defaults.</param>
    /// <param name="defaultMaxTokens">Default max tokens applied when <paramref name="requestSettings"/> is null.</param>
    /// <returns>An instance of <see cref="ChatRequestSettings"/>.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown when <paramref name="requestSettings"/> cannot be converted to <see cref="ChatRequestSettings"/>.
    /// </exception>
    public static ChatRequestSettings FromRequestSettings(AIRequestSettings? requestSettings, int? defaultMaxTokens = null)
    {
        if (requestSettings is null)
        {
            return new ChatRequestSettings()
            {
                MaxTokens = defaultMaxTokens
            };
        }

        // Already the right concrete type: use it as-is.
        if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
        {
            return requestSettingsChatRequestSettings;
        }

        // Fallback: round-trip through JSON so that any AIRequestSettings-derived
        // object with compatible property names is converted field-by-field.
        var json = JsonSerializer.Serialize(requestSettings);
        var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
        if (chatRequestSettings is not null)
        {
            return chatRequestSettings;
        }

        throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
    }

    // Cached serializer options; built once since JsonSerializerOptions is expensive to create.
    private static readonly JsonSerializerOptions s_options = CreateOptions();

    /// <summary>
    /// Builds the lenient serializer options used by <see cref="FromRequestSettings"/>.
    /// </summary>
    private static JsonSerializerOptions CreateOptions()
    {
        JsonSerializerOptions options = new()
        {
            WriteIndented = true,
            MaxDepth = 20,
            AllowTrailingCommas = true,
            PropertyNameCaseInsensitive = true,
            ReadCommentHandling = JsonCommentHandling.Skip,
            Converters = { new ChatRequestSettingsConverter() }
        };

        return options;
    }
}
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

/// <summary>
/// JSON converter for <see cref="ChatRequestSettings"/>.
/// Accepts both PascalCase and snake_case property names on read and
/// writes the canonical snake_case form.
/// </summary>
public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
{
    /// <inheritdoc/>
    public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var requestSettings = new ChatRequestSettings();

        while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
        {
            if (reader.TokenType == JsonTokenType.PropertyName)
            {
                string? propertyName = reader.GetString();

                if (propertyName is not null)
                {
                    // Normalize so that "TopP", "top_p", "TOP_P" etc. all match below.
                    propertyName = propertyName.ToUpperInvariant();
                }

                // Advance to the property value before reading it.
                reader.Read();

                switch (propertyName)
                {
                    case "TEMPERATURE":
                        requestSettings.Temperature = reader.GetDouble();
                        break;
                    case "TOPP":
                    case "TOP_P":
                        requestSettings.TopP = reader.GetDouble();
                        break;
                    case "FREQUENCYPENALTY":
                    case "FREQUENCY_PENALTY":
                        requestSettings.FrequencyPenalty = reader.GetDouble();
                        break;
                    case "PRESENCEPENALTY":
                    case "PRESENCE_PENALTY":
                        requestSettings.PresencePenalty = reader.GetDouble();
                        break;
                    case "MAXTOKENS":
                    case "MAX_TOKENS":
                        requestSettings.MaxTokens = reader.GetInt32();
                        break;
                    case "STOPSEQUENCES":
                    case "STOP_SEQUENCES":
                        requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
                        break;
                    case "RESULTSPERPROMPT":
                    case "RESULTS_PER_PROMPT":
                        requestSettings.ResultsPerPrompt = reader.GetInt32();
                        break;
                    case "TOKENSELECTIONBIASES":
                    case "TOKEN_SELECTION_BIASES":
                        requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
                        break;
                    case "SERVICEID":
                    case "SERVICE_ID":
                        requestSettings.ServiceId = reader.GetString();
                        break;
                    default:
                        // Unknown property: skip its value (and any nested structure).
                        reader.Skip();
                        break;
                }
            }
        }

        return requestSettings;
    }

    /// <inheritdoc/>
    public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
    {
        writer.WriteStartObject();

        writer.WriteNumber("temperature", value.Temperature);
        writer.WriteNumber("top_p", value.TopP);
        writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
        writer.WriteNumber("presence_penalty", value.PresencePenalty);

        // MaxTokens is nullable; emit an explicit null so the property is always present.
        if (value.MaxTokens is null)
        {
            writer.WriteNull("max_tokens");
        }
        else
        {
            writer.WriteNumber("max_tokens", (decimal)value.MaxTokens);
        }

        writer.WritePropertyName("stop_sequences");
        JsonSerializer.Serialize(writer, value.StopSequences, options);

        writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);

        writer.WritePropertyName("token_selection_biases");
        JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);

        writer.WriteString("service_id", value.ServiceId);

        writer.WriteEndObject();
    }
}
| @@ -61,7 +61,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion | |||||
| public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) | public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) | ||||
| { | { | ||||
| var settings = requestSettings != null | var settings = requestSettings != null | ||||
| ? (ChatRequestSettings)requestSettings | |||||
| ? ChatRequestSettings.FromRequestSettings(requestSettings) | |||||
| : defaultRequestSettings; | : defaultRequestSettings; | ||||
| // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable. | // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable. | ||||
| @@ -76,7 +76,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion | |||||
| #pragma warning restore CS1998 | #pragma warning restore CS1998 | ||||
| { | { | ||||
| var settings = requestSettings != null | var settings = requestSettings != null | ||||
| ? (ChatRequestSettings)requestSettings | |||||
| ? ChatRequestSettings.FromRequestSettings(requestSettings) | |||||
| : defaultRequestSettings; | : defaultRequestSettings; | ||||
| // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable. | // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable. | ||||
| @@ -21,7 +21,7 @@ public sealed class LLamaSharpTextCompletion : ITextCompletion | |||||
/// <inheritdoc/>
public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, AIRequestSettings? requestSettings, CancellationToken cancellationToken = default)
{
    // FromRequestSettings never returns null (it returns defaults or throws),
    // so no null-conditional access is needed on `settings` below.
    var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
    var result = executor.InferAsync(text, settings.ToLLamaSharpInferenceParams(), cancellationToken);
    return await Task.FromResult(new List<ITextResult> { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false);
}
| @@ -30,7 +30,7 @@ public sealed class LLamaSharpTextCompletion : ITextCompletion | |||||
/// <inheritdoc/>
public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync(string text, AIRequestSettings? requestSettings, [EnumeratorCancellation] CancellationToken cancellationToken = default)
#pragma warning restore CS1998
{
    // FromRequestSettings never returns null (it returns defaults or throws),
    // so no null-conditional access is needed on `settings` below.
    var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
    var result = executor.InferAsync(text, settings.ToLLamaSharpInferenceParams(), cancellationToken);
    yield return new LLamaTextResult(result);
}