diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 457e7e48..358d70c3 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -95,11 +95,11 @@ namespace LLama
                 Directory.CreateDirectory(path);
             }
             _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
-            if(Executor is StatelessExecutor)
+            if (Executor is StatelessExecutor)
             {
 
             }
-            else if(Executor is StatefulExecutorBase statefulExecutor)
+            else if (Executor is StatefulExecutorBase statefulExecutor)
             {
                 statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
             }
@@ -135,30 +135,54 @@ namespace LLama
         }
 
         /// <summary>
-        /// Get the response from the LLama model. Note that prompt could not only be the preset words,
-        /// but also the question you want to ask.
+        /// Generates a response for a given user prompt and manages history state for the user.
+        /// This will always pass the whole history to the model. Don't pass a whole history
+        /// to this method as the user prompt will be appended to the history of the current session.
+        /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
         /// </summary>
         /// <param name="prompt"></param>
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
-        /// <returns></returns>
+        /// <returns>Returns generated tokens of the assistant message.</returns>
         public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            foreach(var inputTransform in InputTransformPipeline)
+            foreach (var inputTransform in InputTransformPipeline)
                 prompt = inputTransform.Transform(prompt);
-
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
+
+            History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+
+            string internalPrompt = HistoryTransform.HistoryToText(History);
+
             StringBuilder sb = new();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
+
+            await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
             }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
+
+            string assistantMessage = sb.ToString();
+
+            // Remove end tokens from the assistant message
+            // if defined in inferenceParams.AntiPrompts.
+            // We only want the response that was generated and not tokens
+            // that are delimiting the beginning or end of the response.
+            if (inferenceParams?.AntiPrompts != null)
+            {
+                foreach (var stopToken in inferenceParams.AntiPrompts)
+                {
+                    assistantMessage = assistantMessage.Replace(stopToken, "");
+                }
+            }
+
+            History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
         }
 
         /// <summary>
-        /// Get the response from the LLama model with chat histories.
+        /// Generates a response for a given chat history. This method does not manage history state for the user.
+        /// If you want to e.g. truncate the history of a session to fit into the model's context window,
+        /// use this method and pass the truncated history to it. If you don't need this control, use the other
+        /// overload of this method that accepts a user prompt instead.
         /// </summary>
         /// <param name="history"></param>
         /// <param name="inferenceParams"></param>
@@ -167,14 +191,14 @@ namespace LLama
         public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var prompt = HistoryTransform.HistoryToText(history);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
+
             StringBuilder sb = new();
+
             await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
             }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }
 
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
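
For context, a minimal usage sketch of the two ChatAsync overloads after this change. This is not part of the diff; the model path, parameter values, and the ModelParams/LLamaWeights/InteractiveExecutor setup are assumptions based on LLamaSharp's public API of the same era, so adjust to your version:

    using System;
    using System.Collections.Generic;
    using LLama;
    using LLama.Common;

    // Assumption: a local GGUF model at this path.
    var parameters = new ModelParams("models/llama-2-7b.gguf") { ContextSize = 1024 };
    using var weights = LLamaWeights.LoadFromFile(parameters);
    using var context = weights.CreateContext(parameters);
    var session = new ChatSession(new InteractiveExecutor(context));

    var inferenceParams = new InferenceParams
    {
        // Per the new code above, "User:" is stripped from the assistant
        // message before it is stored in the session history.
        AntiPrompts = new List<string> { "User:" }
    };

    // Overload 1: pass only the new user prompt. The session appends it to its
    // History, sends the full history to the model, and records the cleaned
    // assistant message back into History.
    await foreach (var token in session.ChatAsync("Hello, who are you?", inferenceParams))
        Console.Write(token);

    // Overload 2: pass a ChatHistory you manage yourself (e.g. truncated to fit
    // the context window). The session's own History is not modified here.
    var truncated = new ChatHistory();
    truncated.Messages.Add(new ChatHistory.Message(AuthorRole.User, "Summarize the conversation."));
    await foreach (var token in session.ChatAsync(truncated, inferenceParams))
        Console.Write(token);

The split mirrors the two doc comments in the diff: the string overload owns history bookkeeping (append user message, strip anti-prompts, append assistant message), while the ChatHistory overload leaves all history management to the caller.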