Browse Source

Correctly format followup messages in turn-based (chat) inference

tags/0.9.1
Philipp Bauer 2 years ago
parent
commit
629430a087
1 changed files with 23 additions and 13 deletions
  1. +23
    -13
      LLama/ChatSession.cs

+ 23
- 13
LLama/ChatSession.cs View File

@@ -159,15 +159,15 @@ namespace LLama
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: prompt;
: HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(AuthorRole.User, prompt));
}

StringBuilder sb = new();

await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
yield return textToken;
sb.Append(textToken);
}

string assistantMessage = sb.ToString();
@@ -180,7 +180,7 @@ namespace LLama
{
foreach (var stopToken in inferenceParams.AntiPrompts)
{
assistantMessage = assistantMessage.Replace(stopToken, "");
assistantMessage = assistantMessage.Replace(stopToken, "").Trim();
}
}

@@ -209,27 +209,37 @@ namespace LLama
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();

prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: history.Messages.Last().Content;
if (state.IsPromptRun)
{
prompt = HistoryTransform.HistoryToText(History);
}
else
{
ChatHistory.Message lastMessage = history.Messages.Last();
prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
}
}
else
{
prompt = history.Messages.Last().Content;
ChatHistory.Message lastMessage = history.Messages.Last();
prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
}

await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
yield return textToken;
}
}

private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Console.ForegroundColor = ConsoleColor.Gray;
Console.WriteLine(prompt);

var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
{
yield return item;
yield return textToken;
}
}
}

Loading…
Cancel
Save