@@ -1,7 +1,9 @@
using LLama;
using LLama.Abstractions;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using System.Runtime.CompilerServices;
using static LLama.LLamaTransforms;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -10,10 +12,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
/// </summary>
public sealed class LLamaSharpChatCompletion : IChatCompletion
{
private const string UserRole = "user:";
private const string AssistantRole = "assistant:";
private ChatSession session;
private readonly StatelessExecutor _model;
private ChatRequestSettings defaultRequestSettings;
private readonly IHistoryTransform historyTransform;
private readonly ITextStreamTransform outputTransform;
private readonly Dictionary<string, string> _attributes = new();
@@ -30,18 +32,17 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
};
}
public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
public LLamaSharpChatCompletion(StatelessExecutor model,
ChatRequestSettings? defaultRequestSettings = default,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null)
{
this.session = new ChatSession(model)
.WithHistoryTransform(new HistoryTransform())
.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
}
public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
{
this.session = session;
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
this._model = model;
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
this.historyTransform = historyTransform ?? new HistoryTransform();
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
$"{LLama.Common.AuthorRole.Assistant}:",
$"{LLama.Common.AuthorRole.System}:"});
}
/// <inheritdoc/>
@@ -60,14 +61,14 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
/// <inheritdoc/>
public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
var settings = requestSettings != null
var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(result) }.AsReadOnly());
return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(outputTransform.TransformAsync( result) ) }.AsReadOnly());
}
/// <inheritdoc/>
@@ -78,10 +79,10 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory() , settings.ToLLamaSharpInferenceParams(), cancellationToken);
var result = _model.InferAsync(prompt , settings.ToLLamaSharpInferenceParams(), cancellationToken);
yield return new LLamaSharpChatResult(result);
yield return new LLamaSharpChatResult(outputTransform.TransformAsync( result) );
}
}