diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs
index 0f642896..b25563aa 100644
--- a/LLama.Examples/ExampleRunner.cs
+++ b/LLama.Examples/ExampleRunner.cs
@@ -26,6 +26,7 @@ public class ExampleRunner
{ "Semantic Kernel: Prompt", SemanticKernelPrompt.Run },
{ "Semantic Kernel: Chat", SemanticKernelChat.Run },
{ "Semantic Kernel: Store", SemanticKernelMemory.Run },
+ { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
{ "Batched Executor: Fork", BatchedExecutorFork.Run },
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
diff --git a/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
new file mode 100644
index 00000000..af0dea52
--- /dev/null
+++ b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
@@ -0,0 +1,108 @@
+using LLama.Batched;
+using LLama.Common;
+using LLama.Native;
+using LLama.Sampling;
+using Spectre.Console;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// This demonstrates saving and loading the state of a conversation, both to a file and into system memory
+/// </summary>
+public class BatchedExecutorSaveAndLoad
+{
+ private const int n_len = 18;
+
+ public static async Task Run()
+ {
+ string modelPath = UserSettings.GetModelPath();
+
+ var parameters = new ModelParams(modelPath);
+ using var model = LLamaWeights.LoadFromFile(parameters);
+
+ var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
+
+ // Create an executor that can evaluate a batch of conversations together
+ using var executor = new BatchedExecutor(model, parameters);
+
+ // Print some info
+ var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+ Console.WriteLine($"Created executor with model: {name}");
+
+ // Create a conversation
+ var conversation = executor.Create();
+ conversation.Prompt(prompt);
+
+ // Run inference loop
+ var decoder = new StreamingTokenDecoder(executor.Context);
+ var sampler = new DefaultSamplingPipeline();
+ var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);
+
+ // Can't save a conversation while RequiresInference is true
+ if (conversation.RequiresInference)
+ await executor.Infer();
+
+ // Save this conversation to a file and dispose it
+ conversation.Save("demo_conversation.state");
+ conversation.Dispose();
+ AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");
+
+ // Now create a new conversation by loading that state
+ conversation = executor.Load("demo_conversation.state");
+ AnsiConsole.WriteLine("Loaded state");
+
+ // Prompt it again with the last token, so we can continue generating
+ conversation.Rewind(1);
+ conversation.Prompt(lastToken);
+
+ // Continue generating text
+ lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);
+
+ // Can't save a conversation while RequiresInference is true
+ if (conversation.RequiresInference)
+ await executor.Infer();
+
+ // Save the conversation again, this time into system memory
+ using (var state = conversation.Save())
+ {
+ conversation.Dispose();
+ AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");
+
+ // Now create a new conversation by loading that state
+ conversation = executor.Load(state);
+ AnsiConsole.WriteLine("Loaded state");
+ }
+
+ // Prompt it again with the last token, so we can continue generating
+ conversation.Rewind(1);
+ conversation.Prompt(lastToken);
+
+ // Continue generating text
+ await GenerateTokens(executor, conversation, sampler, decoder, n_len);
+
+ // Display final output
+ AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
+ }
+
+ private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
+ {
+ var token = (LLamaToken)0;
+
+ for (var i = 0; i < count; i++)
+ {
+ // Run inference
+ await executor.Infer();
+
+ // Use sampling pipeline to pick a token
+ token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);
+
+ // Add it to the decoder, so it can be converted into text later
+ decoder.Add(token);
+
+ // Prompt the conversation with the token
+ conversation.Prompt(token);
+ }
+
+ return token;
+ }
+}
\ No newline at end of file
diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index 5d637c8e..07389e6e 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -84,6 +84,39 @@ public sealed class BatchedExecutor
return new Conversation(this, GetNextSequenceId());
}
+ /// <summary>
+ /// Load a conversation that was previously saved to a file. Once loaded the conversation will
+ /// need to be prompted.
+ /// </summary>
+ /// <param name="filepath"></param>
+ /// <returns></returns>
+ /// <exception cref="ObjectDisposedException"></exception>
+ public Conversation Load(string filepath)
+ {
+ if (IsDisposed)
+ throw new ObjectDisposedException(nameof(BatchedExecutor));
+
+ var conversation = Create();
+ conversation.Load(filepath);
+ return conversation;
+ }
+
+ /// <summary>
+ /// Load a conversation that was previously saved into memory. Once loaded the conversation will need to be prompted.
+ /// </summary>
+ /// <param name="state"></param>
+ /// <returns></returns>
+ /// <exception cref="ObjectDisposedException"></exception>
+ public Conversation Load(Conversation.State state)
+ {
+ if (IsDisposed)
+ throw new ObjectDisposedException(nameof(BatchedExecutor));
+
+ var conversation = Create();
+ conversation.Load(state);
+ return conversation;
+ }
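+
+ // A minimal usage sketch of the two Load overloads (assumed flow, mirroring the
+ // BatchedExecutorSaveAndLoad example added in this PR):
+ //
+ //     conversation.Save("demo.state");            // or: using var state = conversation.Save();
+ //     conversation.Dispose();
+ //     conversation = executor.Load("demo.state"); // or: executor.Load(state)
+ //     conversation.Prompt(token);                 // a loaded conversation must be prompted again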
+
 /// <summary>
 /// Run inference for all conversations in the batch which have pending tokens.
 /// </summary>
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index cfdb0a1f..2da3da7c 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -2,6 +2,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
+using System.Text.Json;
using LLama.Native;
namespace LLama.Batched;
@@ -14,7 +15,7 @@ public sealed class Conversation
{
private ulong _requiredEpoch;
private LLamaPos _end;
- private int _batchIndex;
+ private int _batchSampleIndex;
private bool _disposed;
private bool _forked;
@@ -107,7 +108,7 @@ public sealed class Conversation
// logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
// they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
- _batchIndex = _batchIndex,
+ _batchSampleIndex = _batchSampleIndex,
_forked = true,
_end = _end,
@@ -140,7 +141,7 @@ public sealed class Conversation
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();
- var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+ var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);
// If necessary copy the span, to protect it from modification. This is only done when
// this conversation has been forked in this epoch.
@@ -220,7 +221,7 @@ public sealed class Conversation
// Add the prompt to the batch
for (var i = 0; i < tokens.Length; i++)
- _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+ _batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
@@ -350,4 +351,168 @@ public sealed class Conversation
/// The new end token position
public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
#endregion
+
+ #region save/load
+ private void AssertCanLoad()
+ {
+ AssertNotDisposed();
+ if (_end.Value > 0)
+ throw new InvalidOperationException("Cannot load into a non-empty conversation");
+ }
+
+ private void AssertCanSave()
+ {
+ AssertNotDisposed();
+ if (RequiresInference)
+ throw new CannotSaveWhileRequiresInferenceException();
+ }
+
+
+ /// <summary>
+ /// Save the complete state of this conversation to a file. If the file already exists it will be overwritten.
+ /// </summary>
+ /// <param name="filepath"></param>
+ /// <exception cref="CannotSaveWhileRequiresInferenceException"></exception>
+ public void Save(string filepath)
+ {
+ AssertCanSave();
+
+ // Prepare extra state to put into file header
+ var state = GetState();
+ var bytes = JsonSerializer.SerializeToUtf8Bytes(state);
+
+ // Save extra state along with the KV cache
+ Executor.Context.SaveState(filepath, ConversationId, bytes);
+ }
+
+ /// <summary>
+ /// Save the complete state of this conversation in system memory.
+ /// </summary>
+ /// <returns></returns>
+ public State Save()
+ {
+ AssertCanSave();
+
+ return new PrivateState(
+ Executor.Context.GetState(ConversationId),
+ GetState()
+ );
+ }
+
+
+ /// <summary>
+ /// Load state from a file.
+ /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
+ /// </summary>
+ /// <param name="filepath"></param>
+ /// <exception cref="InvalidOperationException"></exception>
+ internal void Load(string filepath)
+ {
+ AssertCanLoad();
+
+ // Load the state from file into the KV cache
+ Executor.Context.LoadState(filepath, ConversationId, out var header);
+
+ // deserialize the extra state in the file header
+ var state = JsonSerializer.Deserialize<SerializableConversationState>(header);
+ if (state == null)
+ {
+ Dispose();
+ throw new InvalidOperationException("Failed to deserialize - deserialized header state was null");
+ }
+
+ Load(state);
+ }
+
+ /// <summary>
+ /// Load state from a previously saved state.
+ /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
+ /// </summary>
+ /// <param name="state"></param>
+ internal void Load(State state)
+ {
+ AssertCanLoad();
+
+ // There is only one class that extends State and it is PrivateState, so this cast is safe.
+ var priv = (PrivateState)state;
+
+ // Load the state from memory into the KV cache
+ Executor.Context.LoadState(priv.SequenceState, ConversationId);
+
+ Load(priv.ConversationState);
+ }
+
+
+ private void Load(SerializableConversationState state)
+ {
+ if (state.Version != 1)
+ throw new InvalidOperationException("Failed to deserialize - mismatched version number");
+
+ // Load extra conversation state
+ _end = state.TokenCount;
+ }
+
+ private SerializableConversationState GetState()
+ {
+ return new SerializableConversationState(
+ Version: 1,
+ TokenCount: TokenCount
+ );
+ }
+
+
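+ // Note: only this small record (a format version plus the token count) is serialized as "extra"
+ // conversation state; the heavy per-sequence KV cache data is saved and restored separately by
+ // the LLamaContextExtensions helpers added in this PR.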
+ private record SerializableConversationState(int Version, int TokenCount);
+
+ private sealed class PrivateState
+ : State
+ {
+ public readonly LLamaContext.SequenceState SequenceState;
+ public readonly SerializableConversationState ConversationState;
+
+ public override ulong Size => SequenceState.Size;
+
+ public PrivateState(LLamaContext.SequenceState sequenceState, SerializableConversationState conversationState)
+ {
+ SequenceState = sequenceState;
+ ConversationState = conversationState;
+ }
+
+ /// <inheritdoc />
+ public override void Dispose()
+ {
+ if (IsDisposed)
+ throw new ObjectDisposedException(nameof(State));
+ IsDisposed = true;
+
+ SequenceState.Dispose();
+ }
+ }
+
+ /// <summary>
+ /// In memory saved state of a <see cref="Conversation"/>
+ /// </summary>
+ public abstract class State
+ : IDisposable
+ {
+ /// <summary>
+ /// Indicates if this state has been disposed
+ /// </summary>
+ public bool IsDisposed { get; protected set; }
+
+ /// <summary>
+ /// Get the size in bytes of this state object
+ /// </summary>
+ public abstract ulong Size { get; }
+
+ /// <inheritdoc />
+ public abstract void Dispose();
+
+ /// <summary>
+ /// Internal constructor to prevent anyone outside of LLamaSharp extending this class
+ /// </summary>
+ internal State()
+ {
+ }
+ }
+ #endregion
}
\ No newline at end of file
diff --git a/LLama/Batched/Exceptions.cs b/LLama/Batched/Exceptions.cs
index b025202b..8e225bda 100644
--- a/LLama/Batched/Exceptions.cs
+++ b/LLama/Batched/Exceptions.cs
@@ -57,25 +57,27 @@ public class CannotSampleRequiresPromptException
}
 /// <summary>
-/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotForkWhileRequiresInferenceException
+public class CannotModifyWhileRequiresInferenceException
: ExperimentalBatchedExecutorException
{
- internal CannotForkWhileRequiresInferenceException()
- : base("Cannot `Fork()` a conversation while RequiresInference is true")
+ internal CannotModifyWhileRequiresInferenceException()
+ : base("Cannot `Modify()` a conversation while RequiresInference is true")
{
}
}
 /// <summary>
-/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when "Save()" is called on a <see cref="Conversation"/> which has
+/// already been prompted and before "Infer()" has been called.
+/// <see cref="BatchedExecutor.Infer"/>.
 /// </summary>
-public class CannotModifyWhileRequiresInferenceException
+public class CannotSaveWhileRequiresInferenceException
: ExperimentalBatchedExecutorException
{
- internal CannotModifyWhileRequiresInferenceException()
- : base("Cannot `Modify()` a conversation while RequiresInference is true")
+ internal CannotSaveWhileRequiresInferenceException()
+ : base("Must call `Infer()` before saving this Conversation")
{
}
}
\ No newline at end of file
diff --git a/LLama/Batched/LLamaContextExtensions.cs b/LLama/Batched/LLamaContextExtensions.cs
new file mode 100644
index 00000000..9355301a
--- /dev/null
+++ b/LLama/Batched/LLamaContextExtensions.cs
@@ -0,0 +1,117 @@
+using System;
+using System.Buffers.Binary;
+using System.IO;
+using System.IO.MemoryMappedFiles;
+using LLama.Native;
+
+namespace LLama.Batched;
+
+internal static class LLamaContextExtensions
+{
+ private const uint FileHeaderMagic = 3430400180;
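+
+ // On-disk layout written by SaveState and read back by LoadState (offsets in bytes):
+ //   [0, 4)      magic number, big endian uint32
+ //   [4, 8)      header length N, big endian uint32
+ //   [8, 8+N)    caller supplied "extra" header bytes
+ //   [8+N, end)  raw llama.cpp state data for the saved sequence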
+
+ /// <summary>
+ /// Save the state of a particular sequence to the specified path. Also save some extra data which will be returned when loading.
+ /// Data saved with this method must be loaded with <see cref="LoadState"/>
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="filename"></param>
+ /// <param name="sequence"></param>
+ /// <param name="header"></param>
+ internal static void SaveState(this LLamaContext context, string filename, LLamaSeqId sequence, ReadOnlySpan header)
+ {
+ // Delete that file before overwriting it
+ if (File.Exists(filename))
+ File.Delete(filename);
+
+ // Estimate the size of the state to write to disk. This is always equal to or greater than the actual size.
+ var estimatedStateSize = checked((long)context.NativeHandle.GetStateSize(sequence));
+
+ // Space for the "extra" header data plus an 8 byte prefix (4 byte magic number + 4 byte header length)
+ var prefixSize = header.Length + 8;
+
+ // Total file size is the prefix followed by the (estimated) state data
+ var totalFileSize = prefixSize + estimatedStateSize;
+
+ // Map the file and write the bytes directly to it.
+ long writtenBytes = 0;
+ using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, totalFileSize))
+ {
+ using (var view = file.CreateViewAccessor(0, totalFileSize))
+ {
+ unsafe
+ {
+ byte* ptr = null;
+ view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
+ try
+ {
+ // Write prefix data
+ BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), FileHeaderMagic);
+ writtenBytes += 4;
+ BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), (uint)header.Length);
+ writtenBytes += 4;
+ header.CopyTo(new Span<byte>(ptr + writtenBytes, header.Length));
+ writtenBytes += header.Length;
+
+ // Write state data
+ writtenBytes += (long)context.NativeHandle.GetState(ptr + writtenBytes, (ulong)estimatedStateSize, sequence);
+ }
+ finally
+ {
+ view.SafeMemoryMappedViewHandle.ReleasePointer();
+ }
+ }
+ }
+ }
+
+ // Truncate the file to the actual size of data that was written
+ using (var fileStream = new FileStream(filename, FileMode.Open))
+ fileStream.SetLength(writtenBytes);
+ }
+
+ /// <summary>
+ /// Load the state from the specified path into a particular sequence, also reading the header data. Must only be used with
+ /// data previously saved with <see cref="SaveState"/>
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="filename"></param>
+ /// <param name="sequence"></param>
+ /// <param name="header"></param>
+ /// <exception cref="InvalidOperationException"></exception>
+ internal static void LoadState(this LLamaContext context, string filename, LLamaSeqId sequence, out byte[] header)
+ {
+ // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
+ using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
+ using (var view = file.CreateViewAccessor())
+ {
+ unsafe
+ {
+ byte* ptr = null;
+ view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
+ try
+ {
+ var readBytes = 0;
+
+ // Read header
+ var magic = BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4));
+ readBytes += 4;
+ if (magic != FileHeaderMagic)
+ throw new InvalidOperationException("Invalid file header");
+
+ var headerLength = checked((int)BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4)));
+ readBytes += 4;
+
+ header = new byte[headerLength];
+ new Span<byte>(ptr + readBytes, headerLength).CopyTo(header);
+ readBytes += headerLength;
+
+ context.NativeHandle.SetState(ptr + readBytes, sequence);
+ }
+ finally
+ {
+ view.SafeMemoryMappedViewHandle.ReleasePointer();
+ }
+ }
+ }
+ }
+}
\ No newline at end of file