using LLama.Abstractions;
using LLama.Common;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;

namespace LLama
{
    /// <summary>
    /// The main chat session class.
    /// </summary>
    public class ChatSession
    {
        private ILLamaExecutor _executor;
        private ChatHistory _history;

        // File names written inside the session directory by SaveSession / read by LoadSession.
        private static readonly string _executorStateFilename = "ExecutorState.json";
        private static readonly string _modelStateFilename = "ModelState.st";

        /// <summary>
        /// The executor for this session.
        /// </summary>
        public ILLamaExecutor Executor => _executor;

        /// <summary>
        /// The chat history for this session.
        /// </summary>
        public ChatHistory History => _history;

        /// <summary>
        /// The history transform used in this session.
        /// </summary>
        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

        /// <summary>
        /// The input transform pipeline used in this session.
        /// </summary>
        public List<ITextTransform> InputTransformPipeline { get; set; } = new();

        /// <summary>
        /// The output transform used in this session.
        /// </summary>
        // NOTE(review): public field rather than a property; kept as-is for binary compatibility
        // with existing callers, though a property would match the other transform members.
        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

        /// <summary>
        /// Create a new chat session with an empty history.
        /// </summary>
        /// <param name="executor">The executor for this session</param>
        public ChatSession(ILLamaExecutor executor)
        {
            _executor = executor;
            _history = new ChatHistory();
        }

        /// <summary>
        /// Use a custom history transform.
        /// </summary>
        /// <param name="transform">The history transform to use.</param>
        /// <returns>This session, to allow fluent configuration.</returns>
        public ChatSession WithHistoryTransform(IHistoryTransform transform)
        {
            HistoryTransform = transform;
            return this;
        }

        /// <summary>
        /// Add a text transform to the input transform pipeline.
        /// </summary>
        /// <param name="transform">The input transform to append to the pipeline.</param>
        /// <returns>This session, to allow fluent configuration.</returns>
        public ChatSession AddInputTransform(ITextTransform transform)
        {
            InputTransformPipeline.Add(transform);
            return this;
        }

        /// <summary>
        /// Use a custom output transform.
        /// </summary>
        /// <param name="transform">The output stream transform to use.</param>
        /// <returns>This session, to allow fluent configuration.</returns>
        public ChatSession WithOutputTransform(ITextStreamTransform transform)
        {
            OutputTransform = transform;
            return this;
        }

        /// <summary>
        /// Save the session to a directory: the model (context) state always, and the
        /// executor state as well when the executor is stateful.
        /// </summary>
        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
        /// <exception cref="System.NotImplementedException">
        /// Thrown when the executor is neither a <see cref="StatelessExecutor"/> nor a <see cref="StatefulExecutorBase"/>.
        /// </exception>
        public virtual void SaveSession(string path)
        {
            if (!Directory.Exists(path))
            {
                Directory.CreateDirectory(path);
            }

            _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));

            if (Executor is StatelessExecutor)
            {
                // A stateless executor has no executor state to persist.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
            }
        }

        /// <summary>
        /// Load a session previously written by <see cref="SaveSession"/>.
        /// </summary>
        /// <param name="path">The directory name to load the session.</param>
        /// <exception cref="FileNotFoundException">Thrown when the directory does not exist.</exception>
        /// <exception cref="System.NotImplementedException">
        /// Thrown when the executor is neither a <see cref="StatelessExecutor"/> nor a <see cref="StatefulExecutorBase"/>.
        /// </exception>
        public virtual void LoadSession(string path)
        {
            if (!Directory.Exists(path))
            {
                // NOTE(review): DirectoryNotFoundException would describe this failure more
                // precisely; FileNotFoundException is kept so existing catch blocks keep working.
                throw new FileNotFoundException($"Directory {path} does not exist.");
            }

            _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));

            if (Executor is StatelessExecutor)
            {
                // A stateless executor has no executor state to restore.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
            }
        }

        /// <summary>
        /// Get the response from the LLama model with chat histories.
        /// </summary>
        /// <param name="history">The chat history to render into a prompt via <see cref="HistoryTransform"/>.</param>
        /// <param name="inferenceParams">Inference parameters passed through to the executor; may be null.</param>
        /// <param name="cancellationToken">Token used to cancel the inference.</param>
        /// <returns>The streamed response tokens.</returns>
        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            // NOTE(review): this overload records the whole rendered history as a single user
            // message and does not run InputTransformPipeline (unlike the string overload) —
            // preserved as-is; confirm intent before changing.
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            // Once streaming completes, record the full assistant reply in the history.
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model. Note that prompt could not only be the preset words,
        /// but also the question you want to ask.
        /// </summary>
        /// <param name="prompt">The prompt text; passed through <see cref="InputTransformPipeline"/> first.</param>
        /// <param name="inferenceParams">Inference parameters passed through to the executor; may be null.</param>
        /// <param name="cancellationToken">Token used to cancel the inference.</param>
        /// <returns>The streamed response tokens.</returns>
        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            // Once streaming completes, record the full assistant reply in the history.
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model with chat histories.
        /// </summary>
        /// <param name="history">The chat history to render into a prompt via <see cref="HistoryTransform"/>.</param>
        /// <param name="inferenceParams">Inference parameters passed through to the executor; may be null.</param>
        /// <param name="cancellationToken">Token used to cancel the inference.</param>
        /// <returns>The asynchronously streamed response tokens.</returns>
        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            // NOTE(review): same caveat as the synchronous ChatHistory overload — the rendered
            // history is recorded as one user message and input transforms are not applied.
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            // Once streaming completes, record the full assistant reply in the history.
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model with chat histories asynchronously.
        /// </summary>
        /// <param name="prompt">The prompt text; passed through <see cref="InputTransformPipeline"/> first.</param>
        /// <param name="inferenceParams">Inference parameters passed through to the executor; may be null.</param>
        /// <param name="cancellationToken">Token used to cancel the inference.</param>
        /// <returns>The asynchronously streamed response tokens.</returns>
        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            // Once streaming completes, record the full assistant reply in the history.
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        // Runs inference and pipes the raw token stream through the output transform.
        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
            return OutputTransform.Transform(results);
        }

        // Async counterpart of ChatInternal: runs inference and pipes the stream
        // through the output transform.
        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
            await foreach (var item in OutputTransform.TransformAsync(results))
            {
                yield return item;
            }
        }
    }
}