
Prevent duplication of user prompts / chat history in ChatSession.

The way ChatSession.ChatAsync used the methods provided by an
IHistoryTransform implementation caused unexpected duplication
of chat history messages. It also prevented loading previous
history into the current session.
tags/v0.8.0
Philipp Bauer, 2 years ago
commit a288e7c02b
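
In practice the fixed contract looks like this; a minimal usage sketch, assuming an already-constructed executor and inference parameters (neither is part of this commit):

    // Usage sketch; "executor" and "inferenceParams" are assumed setup, not from this commit.
    var session = new ChatSession(executor);

    // Pass only the new user turn; ChatAsync appends it to session.History itself.
    await foreach (var token in session.ChatAsync("What is the capital of France?", inferenceParams))
        Console.Write(token);

    // Afterwards History has gained exactly two messages, one AuthorRole.User
    // and one AuthorRole.Assistant, instead of a re-appended copy of the transcript.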
1 changed file with 37 additions and 13 deletions:

LLama/ChatSession.cs (+37, -13)

@@ -95,11 +95,11 @@ namespace LLama
                 Directory.CreateDirectory(path);
             }
             _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
-            if(Executor is StatelessExecutor)
+            if (Executor is StatelessExecutor)
             {
 
             }
-            else if(Executor is StatefulExecutorBase statefulExecutor)
+            else if (Executor is StatefulExecutorBase statefulExecutor)
             {
                 statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
             }
@@ -135,30 +135,54 @@ namespace LLama
         }
 
         /// <summary>
-        /// Get the response from the LLama model. Note that prompt could not only be the preset words,
-        /// but also the question you want to ask.
+        /// Generates a response for a given user prompt and manages history state for the user.
+        /// This will always pass the whole history to the model. Don't pass a whole history
+        /// to this method as the user prompt will be appended to the history of the current session.
+        /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
         /// </summary>
         /// <param name="prompt"></param>
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
-        /// <returns></returns>
+        /// <returns>Returns generated tokens of the assistant message.</returns>
         public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            foreach(var inputTransform in InputTransformPipeline)
+            foreach (var inputTransform in InputTransformPipeline)
                 prompt = inputTransform.Transform(prompt);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
 
+            History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+
+            string internalPrompt = HistoryTransform.HistoryToText(History);
+
             StringBuilder sb = new();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
+
+            await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
             }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
+
+            string assistantMessage = sb.ToString();
+
+            // Remove end tokens from the assistant message
+            // if defined in inferenceParams.AntiPrompts.
+            // We only want the response that was generated and not tokens
+            // that are delimiting the beginning or end of the response.
+            if (inferenceParams?.AntiPrompts != null)
+            {
+                foreach (var stopToken in inferenceParams.AntiPrompts)
+                {
+                    assistantMessage = assistantMessage.Replace(stopToken, "");
+                }
+            }
+
+            History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
         }
 
         /// <summary>
-        /// Get the response from the LLama model with chat histories.
+        /// Generates a response for a given chat history. This method does not manage history state for the user.
+        /// If you want to e.g. truncate the history of a session to fit into the model's context window,
+        /// use this method and pass the truncated history to it. If you don't need this control, use the other
+        /// overload of this method that accepts a user prompt instead.
         /// </summary>
         /// <param name="history"></param>
         /// <param name="inferenceParams"></param>
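For clarity, the anti-prompt cleanup above boils down to this standalone sketch (the sample strings are invented; only the Replace loop mirrors the new code):

    // Standalone sketch of the cleanup step; the sample strings are invented.
    var antiPrompts = new[] { "User:" };
    var assistantMessage = "Paris is the capital of France.\nUser:";

    foreach (var stopToken in antiPrompts)
    {
        // Same string.Replace call as in the diff: strips delimiter tokens so
        // only the generated answer is recorded in the session history.
        assistantMessage = assistantMessage.Replace(stopToken, "");
    }
    // assistantMessage is now "Paris is the capital of France.\n"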
@@ -167,14 +191,14 @@ namespace LLama
         public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var prompt = HistoryTransform.HistoryToText(history);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
             StringBuilder sb = new();
+
             await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
             }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }
 
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
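
With history management removed from this overload, truncation works the way the new doc comment suggests; a sketch assuming the "session" and "inferenceParams" from the earlier example (TakeLast needs System.Linq):

    // Sketch of the manual-history pattern; "session" and "inferenceParams" are assumed.
    var trimmed = new ChatHistory();
    trimmed.Messages.AddRange(session.History.Messages.TakeLast(8)); // keep only recent turns

    // This overload renders the supplied history and does not touch session.History;
    // the caller stays responsible for recording the assistant's reply.
    await foreach (var token in session.ChatAsync(trimmed, inferenceParams))
        Console.Write(token);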

