using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates saving the state of a conversation to disk (or to memory) and loading it back into a new conversation
/// </summary>
public class BatchedExecutorSaveAndLoad
{
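    // Number of tokens to generate in each of the three generation passes below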
    private const int n_len = 18;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath);
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Create a conversation
        var conversation = executor.Create();
        conversation.Prompt(prompt);

        // Run inference loop
        var decoder = new StreamingTokenDecoder(executor.Context);
        var sampler = new DefaultSamplingPipeline();
        var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true, so evaluate
        // the token prompted at the end of GenerateTokens before saving
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save this conversation to a file and dispose it
        conversation.Save("demo_conversation.state");
        conversation.Dispose();
        AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");

        // Now create a new conversation by loading that state
        conversation = executor.Load("demo_conversation.state");
        AnsiConsole.WriteLine("Loaded state");
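
        // NOTE: continuing to sample needs fresh logits, which are not restored when
        // loading a state, so rewind one token and re-evaluate it below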
        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save the conversation again, this time into system memory
        using (var state = conversation.Save())
        {
            conversation.Dispose();
            AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");

            // Now create a new conversation by loading that in-memory state
            conversation = executor.Load(state);
            AnsiConsole.WriteLine("Loaded state");
        }

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Display final output (escaped, in case the generated text contains markup characters)
        AnsiConsole.MarkupLine($"[red]{(prompt + decoder.Read()).EscapeMarkup()}[/]");
    }

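    /// <summary>
    /// Run inference for <paramref name="count"/> tokens, adding each sampled token to the
    /// decoder and re-prompting the conversation with it; returns the last token sampled
    /// </summary>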
    private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
    {
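        // Placeholder value; overwritten on the first loop iteration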
        var token = (LLamaToken)0;

        for (var i = 0; i < count; i++)
        {
            // Run inference
            await executor.Infer();

            // Use sampling pipeline to pick a token
            token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);

            // Add it to the decoder, so it can be converted into text later
            decoder.Add(token);

            // Prompt the conversation with the token, ready for the next inference
            conversation.Prompt(token);
        }

        return token;
    }
}