using LLama.Abstractions;
using LLama.Common;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
///
/// The main chat session class.
///
/// <summary>
/// The main chat session class. Wraps an <see cref="ILLamaExecutor"/> together with
/// a <see cref="ChatHistory"/> and optional input/output/history transforms.
/// </summary>
public class ChatSession
{
    private readonly ILLamaExecutor _executor;
    private readonly ChatHistory _history;
    private const string _executorStateFilename = "ExecutorState.json";
    private const string _modelStateFilename = "ModelState.st";

    /// <summary>
    /// The executor for this session.
    /// </summary>
    public ILLamaExecutor Executor => _executor;

    /// <summary>
    /// The chat history for this session.
    /// </summary>
    public ChatHistory History => _history;

    /// <summary>
    /// The history transform used in this session.
    /// </summary>
    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

    /// <summary>
    /// The input transform pipeline used in this session. Each transform is applied
    /// to the user prompt, in order, before it is appended to the history.
    /// </summary>
    public List<ITextTransform> InputTransformPipeline { get; set; } = new();

    /// <summary>
    /// The output transform used in this session.
    /// </summary>
    public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

    /// <summary>
    /// Create a new chat session.
    /// </summary>
    /// <param name="executor">The executor for this session</param>
    public ChatSession(ILLamaExecutor executor)
    {
        _executor = executor;
        _history = new ChatHistory();
    }

    /// <summary>
    /// Use a custom history transform.
    /// </summary>
    /// <param name="transform">The history transform to use.</param>
    /// <returns>This session, for fluent chaining.</returns>
    public ChatSession WithHistoryTransform(IHistoryTransform transform)
    {
        HistoryTransform = transform;
        return this;
    }

    /// <summary>
    /// Add a text transform to the input transform pipeline.
    /// </summary>
    /// <param name="transform">The input transform to append to the pipeline.</param>
    /// <returns>This session, for fluent chaining.</returns>
    public ChatSession AddInputTransform(ITextTransform transform)
    {
        InputTransformPipeline.Add(transform);
        return this;
    }

    /// <summary>
    /// Use a custom output transform.
    /// </summary>
    /// <param name="transform">The output transform to use.</param>
    /// <returns>This session, for fluent chaining.</returns>
    public ChatSession WithOutputTransform(ITextStreamTransform transform)
    {
        OutputTransform = transform;
        return this;
    }

    /// <summary>
    /// Save the session state (model state and, for stateful executors, executor state) to disk.
    /// </summary>
    /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
    /// <exception cref="System.NotImplementedException">Thrown when the executor is a custom type this base class does not know how to persist.</exception>
    public virtual void SaveSession(string path)
    {
        if (!Directory.Exists(path))
        {
            Directory.CreateDirectory(path);
        }
        _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
        if (Executor is StatelessExecutor)
        {
            // A stateless executor keeps no extra state beyond the model context — nothing more to save.
        }
        else if (Executor is StatefulExecutorBase statefulExecutor)
        {
            statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
        }
        else
        {
            throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
        }
    }

    /// <summary>
    /// Load a previously saved session state from disk.
    /// </summary>
    /// <param name="path">The directory name to load the session.</param>
    /// <exception cref="FileNotFoundException">Thrown when <paramref name="path"/> does not exist.</exception>
    /// <exception cref="System.NotImplementedException">Thrown when the executor is a custom type this base class does not know how to restore.</exception>
    public virtual void LoadSession(string path)
    {
        if (!Directory.Exists(path))
        {
            // NOTE(review): FileNotFoundException is kept for backward compatibility even though
            // the missing item is a directory; callers may already catch this exact type.
            throw new FileNotFoundException($"Directory {path} does not exist.");
        }
        _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
        if (Executor is StatelessExecutor)
        {
            // A stateless executor keeps no extra state beyond the model context — nothing more to load.
        }
        else if (Executor is StatefulExecutorBase statefulExecutor)
        {
            statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
        }
        else
        {
            throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
        }
    }

    /// <summary>
    /// Generates a response for a given user prompt and manages history state for the user.
    /// This will always pass the whole history to the model. Don't pass a whole history
    /// to this method as the user prompt will be appended to the history of the current session.
    /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
    /// </summary>
    /// <param name="prompt">The user prompt; it is run through the input transform pipeline and appended to the session history.</param>
    /// <param name="inferenceParams">Optional inference parameters; its anti-prompts are stripped from the stored assistant message.</param>
    /// <param name="cancellationToken">Token used to cancel the generation.</param>
    /// <returns>Returns generated tokens of the assistant message.</returns>
    public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        foreach (var inputTransform in InputTransformPipeline)
            prompt = inputTransform.Transform(prompt);

        History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));

        // The model always sees the full (transformed) history, not just the latest prompt.
        string internalPrompt = HistoryTransform.HistoryToText(History);

        StringBuilder sb = new();
        await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
        {
            yield return result;
            sb.Append(result);
        }

        string assistantMessage = sb.ToString();

        // Remove end tokens from the assistant message
        // if defined in inferenceParams.AntiPrompts.
        // We only want the response that was generated and not tokens
        // that are delimiting the beginning or end of the response.
        if (inferenceParams?.AntiPrompts != null)
        {
            foreach (var stopToken in inferenceParams.AntiPrompts)
            {
                assistantMessage = assistantMessage.Replace(stopToken, "");
            }
        }

        History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
    }

    /// <summary>
    /// Generates a response for a given chat history. This method does not manage history state for the user.
    /// If you want to e.g. truncate the history of a session to fit into the model's context window,
    /// use this method and pass the truncated history to it. If you don't need this control, use the other
    /// overload of this method that accepts a user prompt instead.
    /// </summary>
    /// <param name="history">The chat history to generate from; it is NOT stored into this session's history.</param>
    /// <param name="inferenceParams">Optional inference parameters.</param>
    /// <param name="cancellationToken">Token used to cancel the generation.</param>
    /// <returns>Returns generated tokens of the assistant message.</returns>
    public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var prompt = HistoryTransform.HistoryToText(history);
        StringBuilder sb = new();
        await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
        {
            yield return result;
            sb.Append(result);
        }
    }

    /// <summary>
    /// Runs inference on the given fully-rendered prompt and streams the output through <see cref="OutputTransform"/>.
    /// </summary>
    /// <param name="prompt">The complete prompt text handed to the executor.</param>
    /// <param name="inferenceParams">Optional inference parameters.</param>
    /// <param name="cancellationToken">Token used to cancel the generation.</param>
    /// <returns>The transformed token stream.</returns>
    private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
        await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
        {
            yield return item;
        }
    }
}
}