
BatchedExecutorSaveAndLoad.cs 4.0 kB

using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates saving the state of a conversation to a file and to memory, loading it back, and continuing generation
/// </summary>
public class BatchedExecutorSaveAndLoad
{
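    // Number of tokens to generate in each call to GenerateTokens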
    private const int n_len = 18;

    public static async Task Run()
    {
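        // Load the model weights from the user-configured path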
        string modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath);
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Create a conversation
        var conversation = executor.Create();
        conversation.Prompt(prompt);

        // Run inference loop
        var decoder = new StreamingTokenDecoder(executor.Context);
        var sampler = new DefaultSamplingPipeline();
        var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save this conversation to a file and dispose it
        conversation.Save("demo_conversation.state");
        conversation.Dispose();
        AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");

        // Now create a new conversation by loading that state
        conversation = executor.Load("demo_conversation.state");
        AnsiConsole.WriteLine("Loaded state");

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();
        // Save the conversation again, this time into system memory
        using (var state = conversation.Save())
        {
            conversation.Dispose();
            AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");

            // Now create a new conversation by loading that in-memory state
            conversation = executor.Load(state);
            AnsiConsole.WriteLine("Loaded state");
        }

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Display final output
        AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
    }
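
    /// <summary>
    /// Run inference and sample tokens in a loop, feeding each sampled token back into the conversation
    /// </summary>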
    private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
    {
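        // Track the most recently sampled token, so it can be returned to the caller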
        var token = (LLamaToken)0;
        for (var i = 0; i < count; i++)
        {
            // Run inference
            await executor.Infer();

            // Use sampling pipeline to pick a token
            token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);

            // Add it to the decoder, so it can be converted into text later
            decoder.Add(token);

            // Prompt the conversation with the token
            conversation.Prompt(token);
        }

        return token;
    }
}