Browse Source

Use full history only when the ChatSession runs the first time

tags/v0.8.0
Philipp Bauer 2 years ago
parent
commit
6ea40d1546
1 changed files with 28 additions and 5 deletions
  1. +28
    -5
      LLama/ChatSession.cs

+ 28
- 5
LLama/ChatSession.cs View File

@@ -1,11 +1,14 @@
using LLama.Abstractions; using LLama.Abstractions;
using LLama.Common; using LLama.Common;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using static LLama.InteractiveExecutor;


namespace LLama namespace LLama
{ {
@@ -151,11 +154,17 @@ namespace LLama


History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt)); History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));


string internalPrompt = HistoryTransform.HistoryToText(History);
if (_executor is InteractiveExecutor executor)
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: prompt;
}


StringBuilder sb = new(); StringBuilder sb = new();


await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{ {
yield return result; yield return result;
sb.Append(result); sb.Append(result);
@@ -190,14 +199,28 @@ namespace LLama
/// <returns></returns> /// <returns></returns>
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{ {
var prompt = HistoryTransform.HistoryToText(history);
if (history.Messages.Count == 0)
{
throw new ArgumentException("History must contain at least one message.");
}


StringBuilder sb = new();
string prompt;
if (_executor is InteractiveExecutor executor)
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();

prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: history.Messages.Last().Content;
}
else
{
prompt = history.Messages.Last().Content;
}


await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{ {
yield return result; yield return result;
sb.Append(result);
} }
} }




Loading…
Cancel
Save