Browse Source

Update LLamaSharpChatCompletion Semantic Kernel inference to send only the most recent user message in SK ChatHistory instance when using StatefulExecutor models

pull/671/head
Chirag Karia 1 year ago
parent
commit
50e139b0a2
No known key found for this signature in database GPG Key ID: 412E2F9384BD87B1
1 changed files with 34 additions and 2 deletions
  1. +34
    -2
      LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs

+ 34
- 2
LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs View File

@@ -7,6 +7,7 @@ using System;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using static LLama.InteractiveExecutor;
using static LLama.LLamaTransforms;

namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -22,6 +23,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
private readonly ITextStreamTransform outputTransform;

private readonly Dictionary<string, object?> _attributes = new();
private readonly bool _isStatefulExecutor;

public IReadOnlyDictionary<string, object?> Attributes => this._attributes;

@@ -42,6 +44,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
ITextStreamTransform? outputTransform = null)
{
this._model = model;
this._isStatefulExecutor = this._model is StatefulExecutorBase;
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
this.historyTransform = historyTransform ?? new HistoryTransform();
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
@@ -67,8 +70,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());

string prompt = this._getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = outputTransform.TransformAsync(result);
@@ -88,8 +91,8 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());

string prompt = this._getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = outputTransform.TransformAsync(result);
@@ -99,4 +102,33 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
}
}

/// <summary>
/// Return either the entire formatted chatHistory or just the most recent message based on
/// whether the model extends StatefulExecutorBase or not.
/// </summary>
/// <param name="chatHistory"></param>
/// <returns>The formatted prompt</returns>
/// <summary>
/// Return either the entire formatted chatHistory or just the most recent message based on
/// whether the model extends StatefulExecutorBase or not.
/// </summary>
/// <param name="chatHistory">The Semantic Kernel chat history to format into a prompt.</param>
/// <returns>The formatted prompt</returns>
private string _getFormattedPrompt(ChatHistory chatHistory){
    string prompt;
    if (this._isStatefulExecutor){
        // A stateful executor retains prior conversation context internally, so the
        // full history is only needed on the very first inference (IsPromptRun).
        // Re-sending the whole history afterwards would duplicate context.
        InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)this._model).GetStateData();
        if (state.IsPromptRun)
        {
            prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
        }
        else
        {
            // Subsequent turns: forward only the newest user message.
            // ChatMessageContent.Content is nullable in SK — guard against a
            // null message body so AddUserMessage never receives null.
            ChatHistory temp_history = new();
            temp_history.AddUserMessage(chatHistory.Last().Content ?? string.Empty);
            prompt = historyTransform.HistoryToText(temp_history.ToLLamaSharpChatHistory());
        }
    }
    else
    {
        // Stateless executors need the complete formatted history every call.
        prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
    }

    return prompt;
}
}

Loading…
Cancel
Save