Browse Source

Merge pull request #671 from kidkych/feature/interactive-sk-chatcompletion

Optimize Semantic Kernel LLamaSharpChatCompletion when running with StatefulExecutorBase models
pull/634/merge
Rinne GitHub 1 year ago
parent
commit
0c770a528e
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed file with 34 additions and 2 deletions
  1. +34
    -2
      LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs

+ 34
- 2
LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs View File

@@ -7,6 +7,7 @@ using System;
using System.IO; using System.IO;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using static LLama.InteractiveExecutor;
using static LLama.LLamaTransforms; using static LLama.LLamaTransforms;


namespace LLamaSharp.SemanticKernel.ChatCompletion; namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -22,6 +23,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
private readonly ITextStreamTransform outputTransform; private readonly ITextStreamTransform outputTransform;


private readonly Dictionary<string, object?> _attributes = new(); private readonly Dictionary<string, object?> _attributes = new();
private readonly bool _isStatefulExecutor;


public IReadOnlyDictionary<string, object?> Attributes => this._attributes; public IReadOnlyDictionary<string, object?> Attributes => this._attributes;


@@ -42,6 +44,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
ITextStreamTransform? outputTransform = null) ITextStreamTransform? outputTransform = null)
{ {
this._model = model; this._model = model;
this._isStatefulExecutor = this._model is StatefulExecutorBase;
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings(); this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
this.historyTransform = historyTransform ?? new HistoryTransform(); this.historyTransform = historyTransform ?? new HistoryTransform();
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:", this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
@@ -67,8 +70,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
var settings = executionSettings != null var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings) ? ChatRequestSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings; : defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());


string prompt = this._getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);


var output = outputTransform.TransformAsync(result); var output = outputTransform.TransformAsync(result);
@@ -88,8 +91,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
var settings = executionSettings != null var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings) ? ChatRequestSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings; : defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());


string prompt = this._getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken); var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);


var output = outputTransform.TransformAsync(result); var output = outputTransform.TransformAsync(result);
@@ -99,4 +102,33 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
yield return new StreamingChatMessageContent(AuthorRole.Assistant, token); yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
} }
} }

/// <summary>
/// Builds the prompt string handed to the executor. Stateful executors cache the
/// conversation in their own state, so after the first run only the newest user
/// message needs to be formatted; stateless executors get the full re-serialized
/// history on every call.
/// </summary>
/// <param name="chatHistory">The full Semantic Kernel chat history for this session.</param>
/// <returns>The formatted prompt text passed to the executor's InferAsync.</returns>
private string _getFormattedPrompt(ChatHistory chatHistory){
string prompt;
if (this._isStatefulExecutor){
// NOTE(review): _isStatefulExecutor only guarantees the model is a StatefulExecutorBase,
// but the state is cast to InteractiveExecutorState specifically — a different stateful
// subclass (e.g. an instruct-mode executor) would throw InvalidCastException here.
// TODO confirm only InteractiveExecutor is expected on this path.
InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)this._model).GetStateData();
if (state.IsPromptRun)
{
// First inference: the executor has no cached context yet, so send the entire history.
prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
}
else
{
// Subsequent inferences: earlier turns already live in the executor's state, so only
// the most recent message is formatted and sent (this is the optimization).
// NOTE(review): Last().Content may be null for some message types, and Last() throws
// on an empty history — presumably callers always append a user message first; verify.
ChatHistory temp_history = new();
temp_history.AddUserMessage(chatHistory.Last().Content);
prompt = historyTransform.HistoryToText(temp_history.ToLLamaSharpChatHistory());
}
}
else
{
// Stateless executor: nothing is cached, so re-send the whole formatted history.
prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
}

return prompt;
}
} }

Loading…
Cancel
Save