diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs
index 0f642896..b25563aa 100644
--- a/LLama.Examples/ExampleRunner.cs
+++ b/LLama.Examples/ExampleRunner.cs
@@ -26,6 +26,7 @@ public class ExampleRunner
         { "Semantic Kernel: Prompt", SemanticKernelPrompt.Run },
         { "Semantic Kernel: Chat", SemanticKernelChat.Run },
         { "Semantic Kernel: Store", SemanticKernelMemory.Run },
+        { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
         { "Batched Executor: Fork", BatchedExecutorFork.Run },
         { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
         { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
diff --git a/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
new file mode 100644
index 00000000..af0dea52
--- /dev/null
+++ b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
@@ -0,0 +1,108 @@
+using LLama.Batched;
+using LLama.Common;
+using LLama.Native;
+using LLama.Sampling;
+using Spectre.Console;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// This demonstrates saving and loading the state of a conversation, both to/from a file and to/from system memory
+/// </summary>
+public class BatchedExecutorSaveAndLoad
+{
+    private const int n_len = 18;
+
+    public static async Task Run()
+    {
+        string modelPath = UserSettings.GetModelPath();
+
+        var parameters = new ModelParams(modelPath);
+        using var model = LLamaWeights.LoadFromFile(parameters);
+
+        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
+
+        // Create an executor that can evaluate a batch of conversations together
+        using var executor = new BatchedExecutor(model, parameters);
+
+        // Print some info
+        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        Console.WriteLine($"Created executor with model: {name}");
+
+        // Create a conversation
+        var conversation = executor.Create();
+        conversation.Prompt(prompt);
+
+        // Run inference loop
+        var decoder = new StreamingTokenDecoder(executor.Context);
+        var sampler = new DefaultSamplingPipeline();
+        var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);
+
+        // Can't save a conversation while RequiresInference is true
+        if (conversation.RequiresInference)
+            await executor.Infer();
+
+        // Save this conversation to a file and dispose it
+        conversation.Save("demo_conversation.state");
+        conversation.Dispose();
+        AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");
+
+        // Now create a new conversation by loading that state
+        conversation = executor.Load("demo_conversation.state");
+        AnsiConsole.WriteLine("Loaded state");
+
+        // Prompt it again with the last token, so we can continue generating
+        conversation.Rewind(1);
+        conversation.Prompt(lastToken);
+
+        // Continue generating text
+        lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);
+
+        // Can't save a conversation while RequiresInference is true
+        if (conversation.RequiresInference)
+            await executor.Infer();
+
+        // Save the conversation again, this time into system memory
+        using (var state = conversation.Save())
+        {
+            conversation.Dispose();
+            AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");
+
+            // Now create a new conversation by loading that in-memory state
+            conversation = executor.Load(state);
+            AnsiConsole.WriteLine("Loaded state");
+        }
+
+        // Prompt it again with the last token, so we can continue generating
+        conversation.Rewind(1);
+        conversation.Prompt(lastToken);
+
+        // Continue generating text
+        await GenerateTokens(executor, conversation, sampler, decoder, n_len);
+
+        // Display final output
+        AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
+    }
+
+    private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
+    {
+        var token = (LLamaToken)0;
+
+        for (var i = 0; i < count; i++)
+        {
+            // Run inference
+            await executor.Infer();
+
+            // Use sampling pipeline to pick a token
+            token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);
+
+            // Add it to the decoder, so it can be converted into text later
+            decoder.Add(token);
+
+            // Prompt the conversation with the token
+            conversation.Prompt(token);
+        }
+
+        return token;
+    }
+}
\ No newline at end of file
diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index 5d637c8e..07389e6e 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -84,6 +84,39 @@ public sealed class BatchedExecutor
         return new Conversation(this, GetNextSequenceId());
     }
 
+    /// <summary>
+    /// Load a conversation that was previously saved to a file. Once loaded the conversation will
+    /// need to be prompted.
+    /// </summary>
+    /// <param name="filepath"></param>
+    /// <returns></returns>
+    /// <exception cref="ObjectDisposedException"></exception>
+    public Conversation Load(string filepath)
+    {
+        if (IsDisposed)
+            throw new ObjectDisposedException(nameof(BatchedExecutor));
+
+        var conversation = Create();
+        conversation.Load(filepath);
+        return conversation;
+    }
+
+    /// <summary>
+    /// Load a conversation that was previously saved into memory. Once loaded the conversation will need to be prompted.
+    /// </summary>
+    /// <param name="state"></param>
+    /// <returns></returns>
+    /// <exception cref="ObjectDisposedException"></exception>
+    public Conversation Load(Conversation.State state)
+    {
+        if (IsDisposed)
+            throw new ObjectDisposedException(nameof(BatchedExecutor));
+
+        var conversation = Create();
+        conversation.Load(state);
+        return conversation;
+    }
+
     /// <summary>
     /// Run inference for all conversations in the batch which have pending tokens.
     /// </summary>
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index cfdb0a1f..2da3da7c 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -2,6 +2,7 @@
 using System.Buffers;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
+using System.Text.Json;
 using LLama.Native;
 
 namespace LLama.Batched;
@@ -14,7 +15,7 @@ public sealed class Conversation
 {
     private ulong _requiredEpoch;
     private LLamaPos _end;
-    private int _batchIndex;
+    private int _batchSampleIndex;
     private bool _disposed;
     private bool _forked;
 
@@ -107,7 +108,7 @@ public sealed class Conversation
             // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
             // they both copy the logits before the next sampling run, to fix this issue.
             _requiredEpoch = _requiredEpoch,
-            _batchIndex = _batchIndex,
+            _batchSampleIndex = _batchSampleIndex,
             _forked = true,
 
             _end = _end,
@@ -140,7 +141,7 @@ public sealed class Conversation
         if (_requiredEpoch > Executor.Epoch)
            throw new CannotSampleRequiresInferenceException();
 
-        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);
 
         // If necessary copy the span, to protect it from modification. This is only done when
         // this conversation has been forked in this epoch.
@@ -220,7 +221,7 @@ public sealed class Conversation
 
         // Add the prompt to the batch
         for (var i = 0; i < tokens.Length; i++)
-            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+            _batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
 
         // Mark this conversation as needing inference/sampling
         _requiredEpoch = Executor.Epoch + 1;
@@ -350,4 +351,168 @@ public sealed class Conversation
     /// <returns>The new end token position</returns>
     public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
     #endregion
+
+    #region save/load
+    private void AssertCanLoad()
+    {
+        AssertNotDisposed();
+        if (_end.Value > 0)
+            throw new InvalidOperationException("Cannot load into a non-empty conversation");
+    }
+
+    private void AssertCanSave()
+    {
+        AssertNotDisposed();
+        if (RequiresInference)
+            throw new CannotSaveWhileRequiresInferenceException();
+    }
+
+
+    /// <summary>
+    /// Save the complete state of this conversation to a file. If the file already exists it will be overwritten.
+    /// </summary>
+    /// <param name="filepath"></param>
+    /// <exception cref="CannotSaveWhileRequiresInferenceException"></exception>
+    public void Save(string filepath)
+    {
+        AssertCanSave();
+
+        // Prepare extra state to put into file header
+        var state = GetState();
+        var bytes = JsonSerializer.SerializeToUtf8Bytes(state);
+
+        // Save extra state along with the KV cache
+        Executor.Context.SaveState(filepath, ConversationId, bytes);
+    }
+
+    /// <summary>
+    /// Save the complete state of this conversation in system memory.
+    /// </summary>
+    /// <returns></returns>
+    public State Save()
+    {
+        AssertCanSave();
+
+        return new PrivateState(
+            Executor.Context.GetState(ConversationId),
+            GetState()
+        );
+    }
+
+
+    /// <summary>
+    /// Load state from a file.
+    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
+    /// </summary>
+    /// <param name="filepath"></param>
+    /// <exception cref="InvalidOperationException"></exception>
+    internal void Load(string filepath)
+    {
+        AssertCanLoad();
+
+        // Load the state from file into the KV cache
+        Executor.Context.LoadState(filepath, ConversationId, out var header);
+
+        // Deserialize the extra state stored in the file header
+        var state = JsonSerializer.Deserialize<SerializableConversationState>(header);
+        if (state == null)
+        {
+            Dispose();
+            throw new InvalidOperationException("Failed to deserialize - deserialized header state was null");
+        }
+
+        Load(state);
+    }
+
+    /// <summary>
+    /// Load state from a previously saved state.
+    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
+    /// </summary>
+    /// <param name="state"></param>
+    internal void Load(State state)
+    {
+        AssertCanLoad();
+
+        // There is only one class that extends State and it is PrivateState, so this cast is safe.
+        var priv = (PrivateState)state;
+
+        // Load the in-memory state into the KV cache
+        Executor.Context.LoadState(priv.SequenceState, ConversationId);
+
+        Load(priv.ConversationState);
+    }
+
+
+    private void Load(SerializableConversationState state)
+    {
+        if (state.Version != 1)
+            throw new InvalidOperationException("Failed to deserialize - mismatched version number");
+
+        // Load extra conversation state
+        _end = state.TokenCount;
+    }
+
+    private SerializableConversationState GetState()
+    {
+        return new SerializableConversationState(
+            Version: 1,
+            TokenCount: TokenCount
+        );
+    }
+
+
+    private record SerializableConversationState(int Version, int TokenCount);
+
+    private sealed class PrivateState
+        : State
+    {
+        public readonly LLamaContext.SequenceState SequenceState;
+        public readonly SerializableConversationState ConversationState;
+
+        public override ulong Size => SequenceState.Size;
+
+        public PrivateState(LLamaContext.SequenceState sequenceState, SerializableConversationState conversationState)
+        {
+            SequenceState = sequenceState;
+            ConversationState = conversationState;
+        }
+
+        /// <inheritdoc />
+        public override void Dispose()
+        {
+            if (IsDisposed)
+                throw new ObjectDisposedException(nameof(State));
+            IsDisposed = true;
+
+            SequenceState.Dispose();
+        }
+    }
+
+    /// <summary>
+    /// In memory saved state of a <see cref="Conversation"/>
+    /// </summary>
+    public abstract class State
+        : IDisposable
+    {
+        /// <summary>
+        /// Indicates if this state has been disposed
+        /// </summary>
+        public bool IsDisposed { get; protected set; }
+
+        /// <summary>
+        /// Get the size in bytes of this state object
+        /// </summary>
+        public abstract ulong Size { get; }
+
+        /// <inheritdoc />
+        public abstract void Dispose();
+
+        /// <summary>
+        /// Internal constructor to prevent anyone outside of LLamaSharp from extending this class
+        /// </summary>
+        internal State()
+        {
+        }
+    }
+    #endregion
 }
\ No newline at end of file
diff --git a/LLama/Batched/Exceptions.cs b/LLama/Batched/Exceptions.cs
index b025202b..8e225bda 100644
--- a/LLama/Batched/Exceptions.cs
+++ b/LLama/Batched/Exceptions.cs
@@ -57,25 +57,27 @@ public class CannotSampleRequiresPromptException
 }
 
 /// <summary>
-/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotForkWhileRequiresInferenceException
+public class CannotModifyWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotForkWhileRequiresInferenceException()
-        : base("Cannot `Fork()` a conversation while RequiresInference is true")
+    internal CannotModifyWhileRequiresInferenceException()
+        : base("Cannot `Modify()` a conversation while RequiresInference is true")
     {
     }
 }
 
 /// <summary>
-/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when "Save()" is called on a <see cref="Conversation"/> which has
+/// already been prompted and before "Infer()" has been called.
 /// </summary>
-public class CannotModifyWhileRequiresInferenceException
+public class CannotSaveWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotModifyWhileRequiresInferenceException()
-        : base("Cannot `Modify()` a conversation while RequiresInference is true")
+    internal CannotSaveWhileRequiresInferenceException()
+        : base("Must call `Infer()` before saving this Conversation")
     {
     }
 }
\ No newline at end of file
diff --git a/LLama/Batched/LLamaContextExtensions.cs b/LLama/Batched/LLamaContextExtensions.cs
new file mode 100644
index 00000000..9355301a
--- /dev/null
+++ b/LLama/Batched/LLamaContextExtensions.cs
@@ -0,0 +1,117 @@
+using System;
+using System.Buffers.Binary;
+using System.IO;
+using System.IO.MemoryMappedFiles;
+using LLama.Native;
+
+namespace LLama.Batched;
+
+internal static class LLamaContextExtensions
+{
+    private const uint FileHeaderMagic = 3430400180;
+
+    /// <summary>
+    /// Save the state of a particular sequence to the specified path. Also save some extra header data which will be returned when loading.
+    /// Data saved with this method must be loaded with <see cref="LoadState"/>
+    /// </summary>
+    /// <param name="context"></param>
+    /// <param name="filename"></param>
+    /// <param name="sequence"></param>
+    /// <param name="header"></param>
+    internal static void SaveState(this LLamaContext context, string filename, LLamaSeqId sequence, ReadOnlySpan<byte> header)
+    {
+        // Delete the file before overwriting it
+        if (File.Exists(filename))
+            File.Delete(filename);
+
+        // Estimate size of state to write to disk, this is always equal to or greater than the actual size
+        var estimatedStateSize = checked((long)context.NativeHandle.GetStateSize(sequence));
+
+        // Space for the "extra" header bytes plus an 8 byte prefix (4 byte magic number + 4 byte header length)
+        var prefixSize = header.Length + 8;
+
+        // Add enough space for the prefix and the sequence state
+        var totalFileSize = prefixSize + estimatedStateSize;
+
+        // Map the file and write the bytes directly to it.
+        long writtenBytes = 0;
+        using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, totalFileSize))
+        {
+            using (var view = file.CreateViewAccessor(0, totalFileSize))
+            {
+                unsafe
+                {
+                    byte* ptr = null;
+                    view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
+                    try
+                    {
+                        // Write prefix data
+                        BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), FileHeaderMagic);
+                        writtenBytes += 4;
+                        BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), (uint)header.Length);
+                        writtenBytes += 4;
+                        header.CopyTo(new Span<byte>(ptr + writtenBytes, header.Length));
+                        writtenBytes += header.Length;
+
+                        // Write state data
+                        writtenBytes += (long)context.NativeHandle.GetState(ptr + writtenBytes, (ulong)estimatedStateSize, sequence);
+                    }
+                    finally
+                    {
+                        view.SafeMemoryMappedViewHandle.ReleasePointer();
+                    }
+                }
+            }
+        }
+
+        // Truncate the file to the actual size of data that was written
+        using (var fileStream = new FileStream(filename, FileMode.Open))
+            fileStream.SetLength(writtenBytes);
+    }
+
+    /// <summary>
+    /// Load the state from the specified path into a particular sequence, also reading the header data. Must only be used with
+    /// data previously saved with <see cref="SaveState"/>
+    /// </summary>
+    /// <param name="context"></param>
+    /// <param name="filename"></param>
+    /// <param name="sequence"></param>
+    /// <param name="header"></param>
+    /// <exception cref="InvalidOperationException"></exception>
+    internal static void LoadState(this LLamaContext context, string filename, LLamaSeqId sequence, out byte[] header)
+    {
+        // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
+        using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
+        using (var view = file.CreateViewAccessor())
+        {
+            unsafe
+            {
+                byte* ptr = null;
+                view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
+                try
+                {
+                    var readBytes = 0;
+
+                    // Read header
+                    var magic = BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4));
+                    readBytes += 4;
+                    if (magic != FileHeaderMagic)
+                        throw new InvalidOperationException("Invalid file header");
+
+                    var headerLength = checked((int)BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4)));
+                    readBytes += 4;
+
+                    header = new byte[headerLength];
+                    new Span<byte>(ptr + readBytes, headerLength).CopyTo(header);
+                    readBytes += headerLength;
+
+                    context.NativeHandle.SetState(ptr + readBytes, sequence);
+                }
+                finally
+                {
+                    view.SafeMemoryMappedViewHandle.ReleasePointer();
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
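
Reviewer note: a minimal sketch of how the new public surface fits together end-to-end. This is not part of the diff; the model path, file name and prompt text below are illustrative, and it relies only on the Save/Load members introduced above plus the existing Prompt/Infer API.

using LLama;
using LLama.Batched;
using LLama.Common;

// Load a model and create a batched executor (illustrative model path)
var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Create and prompt a conversation, then run inference so RequiresInference is false
// (otherwise Save() throws CannotSaveWhileRequiresInferenceException)
var conversation = executor.Create();
conversation.Prompt("Not many people know that");
await executor.Infer();

// Save to disk: an 8 byte prefix (magic number + header length), a small JSON header
// (version + token count) and the raw KV cache state for this sequence
conversation.Save("conversation.state");
conversation.Dispose();

// Load it back into a fresh conversation; per the docs above it must be
// prompted again (e.g. Rewind(1) + Prompt(lastToken)) before further inference
conversation = executor.Load("conversation.state");
conversation.Dispose();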