
Correcting non-standard way of working with PromptExecutionSettings

Extending PromptExecutionSettings is not only for ChatCompletion, but also for text completion and text embedding. The former ChatRequestSettings class is therefore renamed to LLamaSharpPromptExecutionSettings and moved from the LLamaSharp.SemanticKernel.ChatCompletion namespace up to LLamaSharp.SemanticKernel.
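
In practice the rename gives Semantic Kernel callers one settings type for every LLamaSharp service. A minimal sketch of the renamed type in use (model path and executor setup are illustrative assumptions, not part of this change):

    using LLama;
    using LLama.Common;
    using LLamaSharp.SemanticKernel;

    // Hypothetical model path, for illustration only.
    var parameters = new ModelParams("path/to/model.gguf");
    using var model = LLamaWeights.LoadFromFile(parameters);
    var executor = new StatelessExecutor(model, parameters);

    // Formerly ChatRequestSettings; no longer tied to chat completion.
    var settings = new LLamaSharpPromptExecutionSettings
    {
        MaxTokens = 100,
        Temperature = 0,
    };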
Branch: pull/689/head
Zoli Somogyi, 1 year ago
commit ab8dd0dfc7
9 changed files with 56 additions and 45 deletions:

  1. LLama.Examples/Examples/SemanticKernelPrompt.cs (+2 -2)
  2. LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs (+6 -6)
  3. LLama.SemanticKernel/ChatRequestSettings.cs (+21 -9)
  4. LLama.SemanticKernel/ChatRequestSettingsConverter.cs (+5 -5)
  5. LLama.SemanticKernel/ExtensionMethods.cs (+3 -4)
  6. LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs (+2 -3)
  7. LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs (+8 -7)
  8. LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs (+8 -8)
  9. LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs (+1 -1)

LLama.Examples/Examples/SemanticKernelPrompt.cs (+2 -2)

@@ -1,9 +1,9 @@
using LLama.Common;
- using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using LLamaSharp.SemanticKernel.TextCompletion;
using Microsoft.SemanticKernel.TextGeneration;
using Microsoft.Extensions.DependencyInjection;
+ using LLamaSharp.SemanticKernel;

namespace LLama.Examples.Examples
{
@@ -31,7 +31,7 @@ namespace LLama.Examples.Examples

One line TLDR with the fewest words.";

- ChatRequestSettings settings = new() { MaxTokens = 100 };
+ LLamaSharpPromptExecutionSettings settings = new() { MaxTokens = 100 };
var summarize = kernel.CreateFunctionFromPrompt(prompt, settings);

string text1 = @"


LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs (+6 -6)

@@ -17,7 +17,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
public sealed class LLamaSharpChatCompletion : IChatCompletionService
{
private readonly ILLamaExecutor _model;
- private ChatRequestSettings defaultRequestSettings;
+ private LLamaSharpPromptExecutionSettings defaultRequestSettings;
private readonly IHistoryTransform historyTransform;
private readonly ITextStreamTransform outputTransform;

@@ -25,9 +25,9 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService

public IReadOnlyDictionary<string, object?> Attributes => this._attributes;

- static ChatRequestSettings GetDefaultSettings()
+ static LLamaSharpPromptExecutionSettings GetDefaultSettings()
{
- return new ChatRequestSettings
+ return new LLamaSharpPromptExecutionSettings
{
MaxTokens = 256,
Temperature = 0,
@@ -37,7 +37,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
}

public LLamaSharpChatCompletion(ILLamaExecutor model,
- ChatRequestSettings? defaultRequestSettings = default,
+ LLamaSharpPromptExecutionSettings? defaultRequestSettings = default,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null)
{
@@ -65,7 +65,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var settings = executionSettings != null
- ? ChatRequestSettings.FromRequestSettings(executionSettings)
+ ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());

@@ -86,7 +86,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var settings = executionSettings != null
- ? ChatRequestSettings.FromRequestSettings(executionSettings)
+ ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
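
For reference, a hedged usage sketch of this service with the renamed settings type; `executor` stands in for any existing ILLamaExecutor:

    using LLamaSharp.SemanticKernel;
    using LLamaSharp.SemanticKernel.ChatCompletion;
    using Microsoft.SemanticKernel.ChatCompletion;

    var chat = new LLamaSharpChatCompletion(
        executor,
        new LLamaSharpPromptExecutionSettings { MaxTokens = 256, Temperature = 0 });

    var history = new ChatHistory();
    history.AddUserMessage("One line TLDR with the fewest words.");

    // Non-streaming call; GetStreamingChatMessageContentsAsync mirrors it.
    var replies = await chat.GetChatMessageContentsAsync(history);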



LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs → LLama.SemanticKernel/ChatRequestSettings.cs (+21 -9)

@@ -1,10 +1,22 @@
using Microsoft.SemanticKernel;

+ /* Unmerged change from project 'LLamaSharp.SemanticKernel (netstandard2.0)'
+ Before:
+ using Microsoft.SemanticKernel;
+ After:
+ using LLamaSharp;
+ using LLamaSharp.SemanticKernel;
+ using LLamaSharp.SemanticKernel;
+ using LLamaSharp.SemanticKernel.ChatCompletion;
+ using Microsoft.SemanticKernel;
+ */
+ using LLamaSharp.SemanticKernel.ChatCompletion;
+ using Microsoft.SemanticKernel;
using System.Text.Json;
using System.Text.Json.Serialization;

- namespace LLamaSharp.SemanticKernel.ChatCompletion;
+ namespace LLamaSharp.SemanticKernel;

- public class ChatRequestSettings : PromptExecutionSettings
+ public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
{
/// <summary>
/// Temperature controls the randomness of the completion.
@@ -68,30 +80,30 @@ public class ChatRequestSettings : PromptExecutionSettings
/// <param name="requestSettings">Template configuration</param>
/// <param name="defaultMaxTokens">Default max tokens</param>
/// <returns>An instance of OpenAIRequestSettings</returns>
- public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
+ public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
{
if (requestSettings is null)
{
- return new ChatRequestSettings()
+ return new LLamaSharpPromptExecutionSettings()
{
MaxTokens = defaultMaxTokens
};
}

- if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
+ if (requestSettings is LLamaSharpPromptExecutionSettings requestSettingsChatRequestSettings)
{
return requestSettingsChatRequestSettings;
}

var json = JsonSerializer.Serialize(requestSettings);
- var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
+ var chatRequestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, s_options);

if (chatRequestSettings is not null)
{
return chatRequestSettings;
}

- throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
+ throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
}

private static readonly JsonSerializerOptions s_options = CreateOptions();
@@ -105,7 +117,7 @@ public class ChatRequestSettings : PromptExecutionSettings
AllowTrailingCommas = true,
PropertyNameCaseInsensitive = true,
ReadCommentHandling = JsonCommentHandling.Skip,
- Converters = { new ChatRequestSettingsConverter() }
+ Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
};

return options;
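
FromRequestSettings above has three paths: null yields defaults, an instance of the LLamaSharp type passes through unchanged, and any other PromptExecutionSettings is round-tripped through JSON using these serializer options. A hedged sketch:

    using System.Collections.Generic;
    using LLamaSharp.SemanticKernel;
    using Microsoft.SemanticKernel;

    // Null falls back to defaults carrying the supplied token budget.
    var defaults = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, defaultMaxTokens: 256);

    // Any other settings object is serialized and re-read as the LLamaSharp
    // type; members it does not declare travel through ExtensionData.
    var generic = new PromptExecutionSettings
    {
        ExtensionData = new Dictionary<string, object> { ["max_tokens"] = 250 }
    };
    var converted = LLamaSharpPromptExecutionSettings.FromRequestSettings(generic);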

LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs → LLama.SemanticKernel/ChatRequestSettingsConverter.cs (+5 -5)

@@ -3,17 +3,17 @@ using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

- namespace LLamaSharp.SemanticKernel.ChatCompletion;
+ namespace LLamaSharp.SemanticKernel;

/// <summary>
/// JSON converter for <see cref="OpenAIRequestSettings"/>
/// </summary>
- public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
+ public class LLamaSharpPromptExecutionSettingsConverter : JsonConverter<LLamaSharpPromptExecutionSettings>
{
/// <inheritdoc/>
- public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ public override LLamaSharpPromptExecutionSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
- var requestSettings = new ChatRequestSettings();
+ var requestSettings = new LLamaSharpPromptExecutionSettings();

while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
{
@@ -77,7 +77,7 @@ public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
}

/// <inheritdoc/>
- public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
+ public override void Write(Utf8JsonWriter writer, LLamaSharpPromptExecutionSettings value, JsonSerializerOptions options)
{
writer.WriteStartObject();


LLama.SemanticKernel/ExtensionMethods.cs (+3 -4)

@@ -1,5 +1,4 @@
- using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;
namespace LLamaSharp.SemanticKernel;

public static class ExtensionMethods
@@ -23,11 +22,11 @@ public static class ExtensionMethods
}

/// <summary>
- /// Convert ChatRequestSettings to LLamaSharp InferenceParams
+ /// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams
/// </summary>
/// <param name="requestSettings"></param>
/// <returns></returns>
- internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this ChatRequestSettings requestSettings)
+ internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings)
{
if (requestSettings is null)
{
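
The extension maps the Semantic Kernel settings onto LLamaSharp's native InferenceParams. It is internal, so a caller sits inside the library or, as in the tests below, in a test project that can see internals (an assumption about the test setup):

    var requestSettings = new LLamaSharpPromptExecutionSettings { MaxTokens = 100 };

    // MaxTokens, Temperature, penalties and friends map onto the
    // corresponding InferenceParams members.
    var inferenceParams = requestSettings.ToLLamaSharpInferenceParams();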


LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs (+2 -3)

@@ -1,5 +1,4 @@
using LLama.Abstractions;
- using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TextGeneration;
@@ -24,7 +23,7 @@ public sealed class LLamaSharpTextCompletion : ITextGenerationService
/// <inheritdoc/>
public async Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
- var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
+ var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
var sb = new StringBuilder();
await foreach (var token in result)
@@ -37,7 +36,7 @@ public sealed class LLamaSharpTextCompletion : ITextGenerationService
/// <inheritdoc/>
public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
+ var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
await foreach (var token in result)
{
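
Usage on the text generation side follows the same pattern; `executor` again stands in for an existing ILLamaExecutor:

    using LLamaSharp.SemanticKernel;
    using LLamaSharp.SemanticKernel.TextCompletion;

    var textService = new LLamaSharpTextCompletion(executor);
    var contents = await textService.GetTextContentsAsync(
        "One line TLDR with the fewest words.",
        new LLamaSharpPromptExecutionSettings { MaxTokens = 100 });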


LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs (+8 -7)

@@ -1,4 +1,5 @@
+ using LLamaSharp.SemanticKernel;
using LLamaSharp.SemanticKernel.ChatCompletion;
using System.Text.Json;

namespace LLama.Unittest.SemanticKernel
@@ -10,11 +11,11 @@ namespace LLama.Unittest.SemanticKernel
{
// Arrange
var options = new JsonSerializerOptions();
- options.Converters.Add(new ChatRequestSettingsConverter());
+ options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
var json = "{}";

// Act
- var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+ var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);

// Assert
Assert.NotNull(requestSettings);
@@ -36,7 +37,7 @@ namespace LLama.Unittest.SemanticKernel
// Arrange
var options = new JsonSerializerOptions();
options.AllowTrailingCommas = true;
- options.Converters.Add(new ChatRequestSettingsConverter());
+ options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
var json = @"{
""frequency_penalty"": 0.5,
""max_tokens"": 250,
@@ -49,7 +50,7 @@ namespace LLama.Unittest.SemanticKernel
}";

// Act
- var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+ var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);

// Assert
Assert.NotNull(requestSettings);
@@ -73,7 +74,7 @@ namespace LLama.Unittest.SemanticKernel
// Arrange
var options = new JsonSerializerOptions();
options.AllowTrailingCommas = true;
- options.Converters.Add(new ChatRequestSettingsConverter());
+ options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
var json = @"{
""FrequencyPenalty"": 0.5,
""MaxTokens"": 250,
@@ -86,7 +87,7 @@ namespace LLama.Unittest.SemanticKernel
}";

// Act
- var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+ var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);

// Assert
Assert.NotNull(requestSettings);


LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs (+8 -8)

@@ -1,4 +1,4 @@
- using LLamaSharp.SemanticKernel.ChatCompletion;
+ using LLamaSharp.SemanticKernel;
using Microsoft.SemanticKernel;

namespace LLama.Unittest.SemanticKernel
@@ -10,7 +10,7 @@ namespace LLama.Unittest.SemanticKernel
{
// Arrange
// Act
- var requestSettings = ChatRequestSettings.FromRequestSettings(null, null);
+ var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, null);

// Assert
Assert.NotNull(requestSettings);
@@ -31,7 +31,7 @@ namespace LLama.Unittest.SemanticKernel
{
// Arrange
// Act
- var requestSettings = ChatRequestSettings.FromRequestSettings(null, 200);
+ var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, 200);

// Assert
Assert.NotNull(requestSettings);
@@ -51,7 +51,7 @@ namespace LLama.Unittest.SemanticKernel
public void ChatRequestSettings_FromExistingRequestSettings()
{
// Arrange
- var originalRequestSettings = new ChatRequestSettings()
+ var originalRequestSettings = new LLamaSharpPromptExecutionSettings()
{
FrequencyPenalty = 0.5,
MaxTokens = 100,
@@ -64,7 +64,7 @@ namespace LLama.Unittest.SemanticKernel
};

// Act
- var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+ var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);
@@ -81,7 +81,7 @@ namespace LLama.Unittest.SemanticKernel
};

// Act
- var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+ var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);
@@ -109,7 +109,7 @@ namespace LLama.Unittest.SemanticKernel
};

// Act
- var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+ var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);
@@ -148,7 +148,7 @@ namespace LLama.Unittest.SemanticKernel
};

// Act
- var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+ var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);


LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs (+1 -1)

@@ -37,7 +37,7 @@ namespace LLamaSharp.SemanticKernel.Tests
public void ToLLamaSharpInferenceParams_StateUnderTest_ExpectedBehavior()
{
// Arrange
- var requestSettings = new ChatRequestSettings();
+ var requestSettings = new LLamaSharpPromptExecutionSettings();

// Act
var result = ExtensionMethods.ToLLamaSharpInferenceParams(

