
🔧 Refactor chat completion implementation

- Refactored the chat completion implementation in `LLamaSharpChatCompletion.cs` to use `StatelessExecutor` instead of `InteractiveExecutor`.
- Updated the chat history prompt in `SemanticKernelChat.cs` to frame the session as a conversation between the assistant and the user.
- Modified the `HistoryTransform` class in `HistoryTransform.cs` to append the assistant role to the chat history prompt.
- Updated the constructor of `LLamaSharpChatCompletion` to accept optional `historyTransform` and `outputTransform` parameters (see the usage sketch after this list).
- Modified `GetChatCompletionsAsync` and its streaming counterpart in `LLamaSharpChatCompletion.cs` to use the new `StatelessExecutor` and `outputTransform`.
- Updated `ExtensionMethods.cs` to include the assistant and system roles in the list of anti-prompts.
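
Taken together, the refactor changes the typical setup to the sketch below. This is a minimal usage sketch, not code from this commit: the model path is a placeholder, and everything else follows the constructor and defaults shown in the diffs that follow.

```csharp
using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;

// StatelessExecutor works from the weights directly, so no long-lived
// LLamaContext has to be created and disposed alongside the executor.
var parameters = new ModelParams("path/to/model.gguf"); // placeholder path
using var model = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(model, parameters);

// historyTransform and outputTransform are optional; when omitted they
// fall back to HistoryTransform and a KeywordTextOutputStreamTransform
// that stops output at the "User:", "Assistant:" and "System:" prefixes.
var chat = new LLamaSharpChatCompletion(executor);
```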
tags/0.9.1
xbotter, 1 year ago
parent commit: a2b26faa7a
4 changed files with 32 additions and 29 deletions:
  1. LLama.Examples/Examples/SemanticKernelChat.cs (+2 -3)
  2. LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs (+4 -4)
  3. LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs (+22 -21)
  4. LLama.SemanticKernel/ExtensionMethods.cs (+4 -1)

LLama.Examples/Examples/SemanticKernelChat.cs (+2 -3)

@@ -16,12 +16,11 @@ namespace LLama.Examples.Examples
 // Load weights into memory
 var parameters = new ModelParams(modelPath);
 using var model = LLamaWeights.LoadFromFile(parameters);
-using var context = model.CreateContext(parameters);
-var ex = new InteractiveExecutor(context);
+var ex = new StatelessExecutor(model, parameters);

 var chatGPT = new LLamaSharpChatCompletion(ex);

-var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books");
+var chatHistory = chatGPT.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");

 Console.WriteLine("Chat content:");
 Console.WriteLine("------------------------");
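
For reference, the example's setup reads as follows once this hunk is applied (reconstructed directly from the diff above; the surrounding method boilerplate is omitted):

```csharp
// Load weights into memory
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);

var chatGPT = new LLamaSharpChatCompletion(ex);

var chatHistory = chatGPT.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");

Console.WriteLine("Chat content:");
Console.WriteLine("------------------------");
```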


LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs (+4 -4)

@@ -1,4 +1,6 @@
-using static LLama.LLamaTransforms;
+using LLama.Common;
+using System.Text;
+using static LLama.LLamaTransforms;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

@@ -10,8 +12,6 @@ public class HistoryTransform : DefaultHistoryTransform
 /// <inheritdoc/>
 public override string HistoryToText(global::LLama.Common.ChatHistory history)
 {
-    var prompt = base.HistoryToText(history);
-    return prompt + "\nAssistant:";
-
+    return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
 }
}
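
The override now ends the prompt with the assistant role marker so the model answers as the assistant. A small sketch of the effect, assuming `DefaultHistoryTransform` renders each message as `Role: content` on its own line and that `ChatHistory.AddMessage(AuthorRole, string)` is the LLama.Common API (both are assumptions; neither is shown in this diff):

```csharp
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;

var history = new ChatHistory();
history.AddMessage(AuthorRole.System, "You are a librarian, expert about books.");
history.AddMessage(AuthorRole.User, "Suggest a good sci-fi novel.");

var prompt = new HistoryTransform().HistoryToText(history);
// Expected shape (base-class formatting assumed):
//   System: You are a librarian, expert about books.
//   User: Suggest a good sci-fi novel.
//   Assistant:
```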

LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs (+22 -21)

@@ -1,7 +1,9 @@
 using LLama;
+using LLama.Abstractions;
 using Microsoft.SemanticKernel.AI;
 using Microsoft.SemanticKernel.AI.ChatCompletion;
 using System.Runtime.CompilerServices;
+using static LLama.LLamaTransforms;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

@@ -10,10 +12,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
 /// </summary>
 public sealed class LLamaSharpChatCompletion : IChatCompletion
 {
-    private const string UserRole = "user:";
-    private const string AssistantRole = "assistant:";
-    private ChatSession session;
+    private readonly StatelessExecutor _model;
     private ChatRequestSettings defaultRequestSettings;
+    private readonly IHistoryTransform historyTransform;
+    private readonly ITextStreamTransform outputTransform;

private readonly Dictionary<string, string> _attributes = new();

@@ -30,18 +32,17 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
     };
 }

-public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
+public LLamaSharpChatCompletion(StatelessExecutor model,
+    ChatRequestSettings? defaultRequestSettings = default,
+    IHistoryTransform? historyTransform = null,
+    ITextStreamTransform? outputTransform = null)
 {
-    this.session = new ChatSession(model)
-        .WithHistoryTransform(new HistoryTransform())
-        .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
-    this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
-}
-
-public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
-{
-    this.session = session;
-    this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
+    this._model = model;
+    this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
+    this.historyTransform = historyTransform ?? new HistoryTransform();
+    this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
+        $"{LLama.Common.AuthorRole.Assistant}:",
+        $"{LLama.Common.AuthorRole.System}:" });
 }

/// <inheritdoc/>
@@ -60,14 +61,14 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
 /// <inheritdoc/>
 public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
 {
     var settings = requestSettings != null
         ? ChatRequestSettings.FromRequestSettings(requestSettings)
         : defaultRequestSettings;
+    var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());

     // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
-    var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
+    var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

-    return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(result) }.AsReadOnly());
+    return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(outputTransform.TransformAsync(result)) }.AsReadOnly());
 }

/// <inheritdoc/>
@@ -78,10 +79,10 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
     var settings = requestSettings != null
         ? ChatRequestSettings.FromRequestSettings(requestSettings)
         : defaultRequestSettings;
+    var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
     // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
-    var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
+    var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

-    yield return new LLamaSharpChatResult(result);
+    yield return new LLamaSharpChatResult(outputTransform.TransformAsync(result));
 }
}
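
Net effect of this file: the class builds a prompt from the Semantic Kernel history via `historyTransform`, streams tokens from the stateless executor, and filters them through `outputTransform` before wrapping the stream in `LLamaSharpChatResult`. A hedged consumption sketch follows; `GetChatMessageAsync` is the pre-1.0 Semantic Kernel `IChatResult` accessor and is an assumption here if you are on a different SK version:

```csharp
var chatHistory = chat.CreateNewChat("You are a librarian, expert about books");
chatHistory.AddUserMessage("Hi, I'm looking for book suggestions");

// Non-streaming path: one result whose message is assembled from the
// transformed token stream.
var results = await chat.GetChatCompletionsAsync(chatHistory);
var message = await results[0].GetChatMessageAsync(); // assumed SK accessor
Console.WriteLine(message.Content);
```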

LLama.SemanticKernel/ExtensionMethods.cs (+4 -1)

@@ -35,7 +35,10 @@ public static class ExtensionMethods
     throw new ArgumentNullException(nameof(requestSettings));
 }

-var antiPrompts = new List<string>(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" };
+var antiPrompts = new List<string>(requestSettings.StopSequences)
+    { LLama.Common.AuthorRole.User.ToString() + ":",
+      LLama.Common.AuthorRole.Assistant.ToString() + ":",
+      LLama.Common.AuthorRole.System.ToString() + ":" };
 return new global::LLama.Common.InferenceParams
 {
     Temperature = (float)requestSettings.Temperature,
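
With the role prefixes registered as anti-prompts, generation stops as soon as the model tries to speak for another role rather than inventing the next turn itself. A small sketch of the resulting mapping, assuming the `antiPrompts` list is assigned to `InferenceParams.AntiPrompts` in code this hunk truncates before showing:

```csharp
var settings = new ChatRequestSettings { Temperature = 0.7 };
var inferenceParams = settings.ToLLamaSharpInferenceParams();
// inferenceParams.AntiPrompts now contains any user-supplied stop
// sequences plus "User:", "Assistant:" and "System:".
```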

