| @@ -7,6 +7,7 @@ using System; | |||
| using System.IO; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| using static LLama.InteractiveExecutor; | |||
| using static LLama.LLamaTransforms; | |||
| namespace LLamaSharp.SemanticKernel.ChatCompletion; | |||
| @@ -22,6 +23,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||
| private readonly ITextStreamTransform outputTransform; | |||
| private readonly Dictionary<string, object?> _attributes = new(); | |||
| private readonly bool _isStatefulExecutor; | |||
| public IReadOnlyDictionary<string, object?> Attributes => this._attributes; | |||
| @@ -42,6 +44,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||
| ITextStreamTransform? outputTransform = null) | |||
| { | |||
| this._model = model; | |||
| this._isStatefulExecutor = this._model is StatefulExecutorBase; | |||
| this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings(); | |||
| this.historyTransform = historyTransform ?? new HistoryTransform(); | |||
| this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:", | |||
| @@ -67,8 +70,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||
| var settings = executionSettings != null | |||
| ? ChatRequestSettings.FromRequestSettings(executionSettings) | |||
| : defaultRequestSettings; | |||
| var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); | |||
| string prompt = this._getFormattedPrompt(chatHistory); | |||
| var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); | |||
| var output = outputTransform.TransformAsync(result); | |||
| @@ -88,8 +91,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||
| var settings = executionSettings != null | |||
| ? ChatRequestSettings.FromRequestSettings(executionSettings) | |||
| : defaultRequestSettings; | |||
| var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); | |||
| string prompt = this._getFormattedPrompt(chatHistory); | |||
| var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); | |||
| var output = outputTransform.TransformAsync(result); | |||
| @@ -99,4 +102,33 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||
| yield return new StreamingChatMessageContent(AuthorRole.Assistant, token); | |||
| } | |||
| } | |||
/// <summary>
/// Return either the entire formatted chat history or just the most recent message,
/// depending on whether the model extends <see cref="StatefulExecutorBase"/> and has
/// already processed its initial prompt (in which case it retains prior context).
/// </summary>
/// <param name="chatHistory">The chat history to format into a prompt.</param>
/// <returns>The formatted prompt.</returns>
private string _getFormattedPrompt(ChatHistory chatHistory)
{
    // Stateless executors have no memory of previous turns, so the full
    // history must be re-sent on every call.
    if (!this._isStatefulExecutor)
    {
        return historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
    }

    // NOTE(review): the previous implementation hard-cast GetStateData() to
    // InteractiveExecutorState, which throws InvalidCastException for other
    // StatefulExecutorBase implementations (e.g. InstructExecutor). Pattern-match
    // instead and fall back to the full history when the state is not interactive,
    // is still on its initial prompt run, or the history is empty (Last() would throw).
    if (((StatefulExecutorBase)this._model).GetStateData() is InteractiveExecutorState { IsPromptRun: false }
        && chatHistory.Count > 0)
    {
        // The executor already holds the conversation context, so only the
        // most recent user message needs to be formatted and sent.
        ChatHistory lastMessageOnly = new();
        lastMessageOnly.AddUserMessage(chatHistory.Last().Content);
        return historyTransform.HistoryToText(lastMessageOnly.ToLLamaSharpChatHistory());
    }

    return historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
}
| } | |||