* Added the ability to save and load individual conversations in a batched executor.
  - New example
  - Added `BatchedExecutor.Load(filepath)` method
  - Added `Conversation.Save(filepath)` method
  - Added new (currently internal) `SaveState`/`LoadState` methods in `LLamaContext` which can stash some extra binary data in the header
* Added the ability to save/load a `Conversation` to an in-memory state, instead of to a file.
* Moved the new save/load methods out to an extension class specifically for the batched executor.
* Removed unnecessary spaces
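Roughly how the new pieces fit together (a minimal sketch, not part of the diff below; the model path, state file name, and prompt text are placeholders):

```csharp
// Sketch only: assumes a GGUF model at "model.gguf"; paths and prompt text are placeholders.
var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

var conversation = executor.Create();
conversation.Prompt("Not many people know that");
await executor.Infer();                  // Save() throws while RequiresInference is true

// Save to a file, or to an in-memory state object
conversation.Save("demo.state");
using var state = conversation.Save();
conversation.Dispose();

// Later, either form can be restored into a fresh conversation
var fromFile = executor.Load("demo.state");
var fromMemory = executor.Load(state);
```

The example added below walks through the same flow end to end, including resuming generation after each load.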
@@ -26,6 +26,7 @@ public class ExampleRunner
        { "Semantic Kernel: Prompt", SemanticKernelPrompt.Run },
        { "Semantic Kernel: Chat", SemanticKernelChat.Run },
        { "Semantic Kernel: Store", SemanticKernelMemory.Run },
+       { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
        { "Batched Executor: Fork", BatchedExecutorFork.Run },
        { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
        { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
@@ -0,0 +1,108 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates saving the state of a conversation to a file (and to memory), then loading it back into a new conversation.
/// </summary>
public class BatchedExecutorSaveAndLoad
{
    private const int n_len = 18;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Create a conversation
        var conversation = executor.Create();
        conversation.Prompt(prompt);

        // Run inference loop
        var decoder = new StreamingTokenDecoder(executor.Context);
        var sampler = new DefaultSamplingPipeline();
        var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save this conversation to a file and dispose it
        conversation.Save("demo_conversation.state");
        conversation.Dispose();
        AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");

        // Now create a new conversation by loading that state
        conversation = executor.Load("demo_conversation.state");
        AnsiConsole.WriteLine("Loaded state");

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save the conversation again, this time into system memory
        using (var state = conversation.Save())
        {
            conversation.Dispose();
            AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");
            // Now create a new conversation by loading that in-memory state
            conversation = executor.Load(state);
            AnsiConsole.WriteLine("Loaded state");
        }

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        await GenerateTokens(executor, conversation, sampler, decoder, n_len);
        // Display final output
        AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
    }

    private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
    {
        var token = (LLamaToken)0;
        for (var i = 0; i < count; i++)
        {
            // Run inference
            await executor.Infer();

            // Use sampling pipeline to pick a token
            token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);

            // Add it to the decoder, so it can be converted into text later
            decoder.Add(token);

            // Prompt the conversation with the token
            conversation.Prompt(token);
        }
        return token;
    }
}
@@ -84,6 +84,39 @@ public sealed class BatchedExecutor
        return new Conversation(this, GetNextSequenceId());
    }

    /// <summary>
    /// Load a conversation that was previously saved to a file. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="filepath"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(string filepath)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(filepath);
        return conversation;
    }

    /// <summary>
    /// Load a conversation that was previously saved into memory. Once loaded the conversation will need to be prompted.
    /// </summary>
    /// <param name="state"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(Conversation.State state)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(state);
        return conversation;
    }

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
@@ -2,6 +2,7 @@
 using System.Buffers;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
+using System.Text.Json;
 using LLama.Native;

 namespace LLama.Batched;
@@ -14,7 +15,7 @@ public sealed class Conversation
 {
     private ulong _requiredEpoch;
     private LLamaPos _end;
-    private int _batchIndex;
+    private int _batchSampleIndex;
     private bool _disposed;
     private bool _forked;
@@ -107,7 +108,7 @@ public sealed class Conversation
            // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
            // they both copy the logits before the next sampling run, to fix this issue.
            _requiredEpoch = _requiredEpoch,
-           _batchIndex = _batchIndex,
+           _batchSampleIndex = _batchSampleIndex,
            _forked = true,
            _end = _end,
@@ -140,7 +141,7 @@ public sealed class Conversation
        if (_requiredEpoch > Executor.Epoch)
            throw new CannotSampleRequiresInferenceException();

-       var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+       var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);

        // If necessary copy the span, to protect it from modification. This is only done when
        // this conversation has been forked in this epoch.
@@ -220,7 +221,7 @@ public sealed class Conversation
        // Add the prompt to the batch
        for (var i = 0; i < tokens.Length; i++)
-           _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+           _batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
@@ -350,4 +351,168 @@ public sealed class Conversation
    /// <returns>The new end token position</returns>
    public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
    #endregion

    #region save/load
    private void AssertCanLoad()
    {
        AssertNotDisposed();
        if (_end.Value > 0)
            throw new InvalidOperationException("Cannot load into a non-empty conversation");
    }

    private void AssertCanSave()
    {
        AssertNotDisposed();
        if (RequiresInference)
            throw new CannotSaveWhileRequiresInferenceException();
    }
    /// <summary>
    /// Save the complete state of this conversation to a file. If the file already exists it will be overwritten.
    /// </summary>
    /// <param name="filepath"></param>
    /// <exception cref="CannotSaveWhileRequiresInferenceException"></exception>
    public void Save(string filepath)
    {
        AssertCanSave();

        // Prepare extra state to put into the file header
        var state = GetState();
        var bytes = JsonSerializer.SerializeToUtf8Bytes(state);

        // Save extra state along with the KV cache
        Executor.Context.SaveState(filepath, ConversationId, bytes);
    }
    /// <summary>
    /// Save the complete state of this conversation in system memory.
    /// </summary>
    /// <returns></returns>
    public State Save()
    {
        AssertCanSave();

        return new PrivateState(
            Executor.Context.GetState(ConversationId),
            GetState()
        );
    }

    /// <summary>
    /// Load state from a file.
    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
    /// </summary>
    /// <param name="filepath"></param>
    /// <exception cref="InvalidOperationException"></exception>
    internal void Load(string filepath)
    {
        AssertCanLoad();

        // Load the state from file into the KV cache
        Executor.Context.LoadState(filepath, ConversationId, out var header);

        // Deserialize the extra state in the file header
        var state = JsonSerializer.Deserialize<SerializableConversationState>(header);
        if (state == null)
        {
            Dispose();
            throw new InvalidOperationException("Failed to deserialize - deserialized header state was null");
        }

        Load(state);
    }

    /// <summary>
    /// Load state from a previously saved state.
    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
    /// </summary>
    /// <param name="state"></param>
    internal void Load(State state)
    {
        AssertCanLoad();

        // There is only one class that extends State and it is PrivateState, so this cast is safe.
        var priv = (PrivateState)state;

        // Load the state into the KV cache
        Executor.Context.LoadState(priv.SequenceState, ConversationId);

        Load(priv.ConversationState);
    }
    private void Load(SerializableConversationState state)
    {
        if (state.Version != 1)
            throw new InvalidOperationException("Failed to deserialize - mismatched version number");

        // Load extra conversation state
        _end = state.TokenCount;
    }

    private SerializableConversationState GetState()
    {
        return new SerializableConversationState(
            Version: 1,
            TokenCount: TokenCount
        );
    }

    private record SerializableConversationState(int Version, int TokenCount);

    private sealed class PrivateState
        : State
    {
        public readonly LLamaContext.SequenceState SequenceState;
        public readonly SerializableConversationState ConversationState;

        public override ulong Size => SequenceState.Size;

        public PrivateState(LLamaContext.SequenceState sequenceState, SerializableConversationState conversationState)
        {
            SequenceState = sequenceState;
            ConversationState = conversationState;
        }

        /// <inheritdoc />
        public override void Dispose()
        {
            if (IsDisposed)
                throw new ObjectDisposedException(nameof(State));
            IsDisposed = true;

            SequenceState.Dispose();
        }
    }

    /// <summary>
    /// In memory saved state of a <see cref="Conversation"/>
    /// </summary>
    public abstract class State
        : IDisposable
    {
        /// <summary>
        /// Indicates if this state has been disposed
        /// </summary>
        public bool IsDisposed { get; protected set; }

        /// <summary>
        /// Get the size in bytes of this state object
        /// </summary>
        public abstract ulong Size { get; }

        /// <inheritdoc />
        public abstract void Dispose();
        /// <summary>
        /// Internal constructor to prevent anyone outside of LLamaSharp from extending this class
        /// </summary>
        internal State()
        {
        }
    }
    #endregion
}
@@ -57,25 +57,27 @@ public class CannotSampleRequiresPromptException
}

/// <summary>
-/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>
-public class CannotForkWhileRequiresInferenceException
+public class CannotModifyWhileRequiresInferenceException
    : ExperimentalBatchedExecutorException
{
-   internal CannotForkWhileRequiresInferenceException()
-       : base("Cannot `Fork()` a conversation while RequiresInference is true")
+   internal CannotModifyWhileRequiresInferenceException()
+       : base("Cannot `Modify()` a conversation while RequiresInference is true")
    {
    }
}

/// <summary>
-/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when "Save()" is called on a <see cref="Conversation"/> which has
+/// already been prompted and before "Infer()" has been called on the
+/// <see cref="BatchedExecutor"/>.
/// </summary>
-public class CannotModifyWhileRequiresInferenceException
+public class CannotSaveWhileRequiresInferenceException
    : ExperimentalBatchedExecutorException
{
-   internal CannotModifyWhileRequiresInferenceException()
-       : base("Cannot `Modify()` a conversation while RequiresInference is true")
+   internal CannotSaveWhileRequiresInferenceException()
+       : base("Must call `Infer()` before saving this Conversation")
    {
    }
}
@@ -0,0 +1,117 @@
using System;
using System.Buffers.Binary;
using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Native;

namespace LLama.Batched;

internal static class LLamaContextExtensions
{
    private const uint FileHeaderMagic = 3430400180;
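
    // File layout written by SaveState below (integers are big-endian):
    //   bytes [0, 4)     magic number (FileHeaderMagic)
    //   bytes [4, 8)     length N of the caller-supplied header
    //   bytes [8, 8+N)   caller-supplied header bytes
    //   bytes [8+N, ...) native sequence state data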

    /// <summary>
    /// Save the state of a particular sequence to the specified path. Also saves some extra header data which will be returned when loading.
    /// Data saved with this method <b>must</b> be loaded with <see cref="LoadState(LLamaContext, string, LLamaSeqId, out byte[])"/>
    /// </summary>
    /// <param name="context"></param>
    /// <param name="filename"></param>
    /// <param name="sequence"></param>
    /// <param name="header"></param>
    internal static void SaveState(this LLamaContext context, string filename, LLamaSeqId sequence, ReadOnlySpan<byte> header)
    {
        // Delete the file before overwriting it
        if (File.Exists(filename))
            File.Delete(filename);

        // Estimate the size of the state to write to disk. This is always equal to or greater than the actual size.
        var estimatedStateSize = checked((long)context.NativeHandle.GetStateSize(sequence));

        // Space for the "extra" header bytes plus an 8 byte prefix (4 byte magic number + 4 byte header length)
        var prefixSize = header.Length + 8;

        // Total file size is the prefix plus the (estimated) sequence state
        var totalFileSize = prefixSize + estimatedStateSize;

        // Map the file and write the bytes directly to it.
        long writtenBytes = 0;
        using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, totalFileSize))
        {
            using (var view = file.CreateViewAccessor(0, totalFileSize))
            {
                unsafe
                {
                    byte* ptr = null;
                    view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                    try
                    {
                        // Write prefix data
                        BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), FileHeaderMagic);
                        writtenBytes += 4;
                        BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), (uint)header.Length);
                        writtenBytes += 4;
                        header.CopyTo(new Span<byte>(ptr + writtenBytes, header.Length));
                        writtenBytes += header.Length;

                        // Write state data
                        writtenBytes += (long)context.NativeHandle.GetState(ptr + writtenBytes, (ulong)estimatedStateSize, sequence);
                    }
                    finally
                    {
                        view.SafeMemoryMappedViewHandle.ReleasePointer();
                    }
                }
            }
        }

        // Truncate the file to the actual size of data that was written
        using (var fileStream = new FileStream(filename, FileMode.Open))
            fileStream.SetLength(writtenBytes);
    }
    /// <summary>
    /// Load the state from the specified path into a particular sequence, also reading back the header data. Must only be used with
    /// data previously saved with <see cref="SaveState(LLamaContext, string, LLamaSeqId, ReadOnlySpan{byte})"/>
    /// </summary>
    /// <param name="context"></param>
    /// <param name="filename"></param>
    /// <param name="sequence"></param>
    /// <param name="header"></param>
    /// <exception cref="InvalidOperationException"></exception>
    internal static void LoadState(this LLamaContext context, string filename, LLamaSeqId sequence, out byte[] header)
    {
        // Map the state file into memory and pass that pointer directly to `llama_set_state_data` to load from
        using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
        using (var view = file.CreateViewAccessor())
        {
            unsafe
            {
                byte* ptr = null;
                view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                try
                {
                    var readBytes = 0;

                    // Read header
                    var magic = BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4));
                    readBytes += 4;
                    if (magic != FileHeaderMagic)
                        throw new InvalidOperationException("Invalid file header");

                    var headerLength = checked((int)BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4)));
                    readBytes += 4;
                    header = new byte[headerLength];
                    new Span<byte>(ptr + readBytes, headerLength).CopyTo(header);
                    readBytes += headerLength;

                    // Load the remaining bytes as the sequence state
                    context.NativeHandle.SetState(ptr + readBytes, sequence);
                }
                finally
                {
                    view.SafeMemoryMappedViewHandle.ReleasePointer();
                }
            }
        }
    }
}
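For reference, a minimal sketch of the round trip these internal helpers provide (assuming an existing `LLamaContext` named `context`; the path and sequence id are placeholders):

```csharp
// Sketch only: round-tripping a sequence state plus caller-defined header bytes.
byte[] extra = { 1, 2, 3 };                            // arbitrary header bytes to stash in the file

// Writes: 4-byte magic, 4-byte header length, the header bytes, then the sequence state
context.SaveState("seq.state", (LLamaSeqId)0, extra);

// Reads the same layout back; `header` contains exactly the bytes passed to SaveState
context.LoadState("seq.state", (LLamaSeqId)0, out var header);
```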