# Batched executor - basic guidance
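This example uses a `BatchedExecutor` to evaluate two conversations in a single batch: a "guided" conversation started from a positive prompt, and a "guidance" conversation started from a negative prompt whose logits steer the guided one (classifier-free guidance). A fork of the guided conversation is also run without guidance so the two outputs can be compared.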
```cs
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
    private const int n_len = 32;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        using var guided = executor.Create();
        guided.Prompt(positivePrompt);
        using var guidance = executor.Create();
        guidance.Prompt(negativePrompt);
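
        // Prompt() only queues tokens: nothing is evaluated until Infer() is
        // called, which processes everything queued by both conversations in a
        // single batched forward pass.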

        // Run inference to evaluate prompts
        await AnsiConsole
            .Status()
            .Spinner(Spinner.Known.Line)
            .StartAsync("Evaluating Prompts...", _ => executor.Infer());

        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        using var unguided = guided.Fork();
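        // The fork shares everything evaluated so far, so the unguided copy
        // does not re-evaluate the prompt.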

        // Run inference loop
        var unguidedSampler = new GuidedSampler(null, weight);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance, weight);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
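                    // The prompt evaluation above already produced logits for
                    // the first token, so Infer() is skipped on iteration zero.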
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
                    // guidance. It serves as a comparison to show the effect of guidance.
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
                    // to steer the output. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt)
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.EndOfSentenceToken)
                        break;

                    // Update progress bar
                    reporter.Increment(1);
                }
            });

        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
    }

    private class GuidedSampler(Conversation? guidance, float weight)
        : BaseSamplingPipeline
    {
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
        }

        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            if (guidance == null)
                return;

            // Get the logits generated by the guidance sequence
            var guidanceLogits = guidance.Sample();

            // Use those logits to guide this sequence
            NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);
            return candidates.SampleToken(ctx);
        }
    }
}
```
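
The interesting work happens in `GuidedSampler.ProcessLogits`: before each guided token is sampled, the logits from the guidance (negative prompt) conversation are used to reweight the guided logits, moving them away from what the negative prompt predicts. Below is a minimal sketch of that blend, assuming the standard classifier-free-guidance interpolation; the real `llama_sample_apply_guidance` in llama.cpp also normalises both arrays with a log-softmax first, so this is illustrative only:

```cs
// Illustrative sketch of the guidance blend - not the actual native implementation.
static void ApplyGuidance(Span<float> logits, ReadOnlySpan<float> guidanceLogits, float weight)
{
    for (var i = 0; i < logits.Length; i++)
    {
        // weight == 1 leaves the guided logits unchanged; weight > 1 pushes
        // the result further away from the negative-prompt distribution.
        logits[i] = guidanceLogits[i] + weight * (logits[i] - guidanceLogits[i]);
    }
}
```

With the default weight of 2, continuations the negative prompt favours are suppressed, so the guided output should visibly differ from the unguided fork printed alongside it.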