| @@ -7,6 +7,7 @@ using System; | |||||
| using System.IO; | using System.IO; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Text; | using System.Text; | ||||
| using static LLama.InteractiveExecutor; | |||||
| using static LLama.LLamaTransforms; | using static LLama.LLamaTransforms; | ||||
| namespace LLamaSharp.SemanticKernel.ChatCompletion; | namespace LLamaSharp.SemanticKernel.ChatCompletion; | ||||
| @@ -22,6 +23,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||||
| private readonly ITextStreamTransform outputTransform; | private readonly ITextStreamTransform outputTransform; | ||||
| private readonly Dictionary<string, object?> _attributes = new(); | private readonly Dictionary<string, object?> _attributes = new(); | ||||
| private readonly bool _isStatefulExecutor; | |||||
| public IReadOnlyDictionary<string, object?> Attributes => this._attributes; | public IReadOnlyDictionary<string, object?> Attributes => this._attributes; | ||||
| @@ -42,6 +44,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||||
| ITextStreamTransform? outputTransform = null) | ITextStreamTransform? outputTransform = null) | ||||
| { | { | ||||
| this._model = model; | this._model = model; | ||||
| this._isStatefulExecutor = this._model is StatefulExecutorBase; | |||||
| this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings(); | this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings(); | ||||
| this.historyTransform = historyTransform ?? new HistoryTransform(); | this.historyTransform = historyTransform ?? new HistoryTransform(); | ||||
| this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:", | this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:", | ||||
| @@ -67,8 +70,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||||
| var settings = executionSettings != null | var settings = executionSettings != null | ||||
| ? ChatRequestSettings.FromRequestSettings(executionSettings) | ? ChatRequestSettings.FromRequestSettings(executionSettings) | ||||
| : defaultRequestSettings; | : defaultRequestSettings; | ||||
| var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); | |||||
| string prompt = this._getFormattedPrompt(chatHistory); | |||||
| var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); | var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); | ||||
| var output = outputTransform.TransformAsync(result); | var output = outputTransform.TransformAsync(result); | ||||
| @@ -88,8 +91,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||||
| var settings = executionSettings != null | var settings = executionSettings != null | ||||
| ? ChatRequestSettings.FromRequestSettings(executionSettings) | ? ChatRequestSettings.FromRequestSettings(executionSettings) | ||||
| : defaultRequestSettings; | : defaultRequestSettings; | ||||
| var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); | |||||
| string prompt = this._getFormattedPrompt(chatHistory); | |||||
| var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); | var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); | ||||
| var output = outputTransform.TransformAsync(result); | var output = outputTransform.TransformAsync(result); | ||||
| @@ -99,4 +102,33 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService | |||||
| yield return new StreamingChatMessageContent(AuthorRole.Assistant, token); | yield return new StreamingChatMessageContent(AuthorRole.Assistant, token); | ||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Builds the prompt text that is handed to the executor's <c>InferAsync</c>.
/// </summary>
/// <remarks>
/// The whole <paramref name="chatHistory"/> is formatted through <c>historyTransform</c> unless the model
/// is a stateful executor whose state data is an <see cref="InteractiveExecutorState"/> with
/// <c>IsPromptRun</c> equal to <c>false</c>. In that case only the content of the last message is
/// formatted, re-wrapped as a single user message.
/// </remarks>
/// <param name="chatHistory">The conversation to turn into a prompt.</param>
/// <returns>The formatted prompt</returns>
private string _getFormattedPrompt(ChatHistory chatHistory)
{
    // Match on the concrete state type instead of hard-casting: a stateful executor whose state is
    // not an InteractiveExecutorState (or is null) falls through to the full-history path below
    // rather than failing with InvalidCastException / NullReferenceException.
    if (this._isStatefulExecutor
        && ((StatefulExecutorBase)this._model).GetStateData() is InteractiveExecutorState { IsPromptRun: false })
    {
        // NOTE(review): presumably the executor already holds the earlier turns in its own state once
        // the prompt run is over, so only the newest message is sent — confirm against the executor.
        ChatHistory latestMessageOnly = new();
        latestMessageOnly.AddUserMessage(chatHistory.Last().Content);
        return historyTransform.HistoryToText(latestMessageOnly.ToLLamaSharpChatHistory());
    }

    // Stateless executor, first (prompt) run, or unrecognized state type: send the full history.
    return historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
}
| } | } | ||||