
BatchedExecutorGuidance.cs

using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
    private const int n_len = 32;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        using var guided = executor.Prompt(positivePrompt);
        using var guidance = executor.Prompt(negativePrompt);

        // Run inference to evaluate prompts
        await AnsiConsole
            .Status()
            .Spinner(Spinner.Known.Line)
            .StartAsync("Evaluating Prompts...", _ => executor.Infer());
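        // Note: a single call to Infer() evaluates every conversation with pending
        // tokens as one batch, so both prompts above are processed together here.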
        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        using var unguided = guided.Fork();

        // Run inference loop
        var unguidedSampler = new GuidedSampler(null, weight);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance, weight);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
                    // guidance. This serves as a comparison to show the effect of guidance.
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);
                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt).
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.EndOfSentenceToken)
                        break;

                    // Update progress bar
                    reporter.Increment(1);
                }
            });

        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
    }
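    // Because both conversations were evaluated in the same batch, by the time
    // sampling happens guidance.Sample() exposes the logits produced by the
    // negative prompt. GuidedSampler blends those into the guided sequence's
    // logits before the usual temperature/top-k sampling runs.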
    private class GuidedSampler(Conversation? guidance, float weight)
        : BaseSamplingPipeline
    {
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
        }

        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            if (guidance == null)
                return;

            // Get the logits generated by the guidance sequences
            var guidanceLogits = guidance.Sample();

            // Use those logits to guide this sequence
            NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);
            return candidates.SampleToken(ctx);
        }
    }
}
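For reference, here is a minimal sketch of the update that classifier-free guidance applies to the logits. It mirrors what llama.cpp's llama_sample_apply_guidance does (both vectors are log-softmax normalised, then each guided logit is pushed away from the corresponding guidance logit by the weight), but the class and helper names below are illustrative, not part of the LLamaSharp API:

using System;

internal static class GuidanceMathSketch
{
    // Sketch of the classifier-free-guidance logit update; "weight" corresponds
    // to the scale passed to llama_sample_apply_guidance above.
    public static void ApplyGuidance(Span<float> logits, Span<float> guidanceLogits, float weight)
    {
        LogSoftmax(logits);
        LogSoftmax(guidanceLogits);

        // Push the guided distribution away from the guidance distribution:
        // weight = 1 leaves the logits unchanged, weight > 1 amplifies whatever
        // distinguishes the positive prompt's predictions from the negative prompt's.
        for (var i = 0; i < logits.Length; i++)
            logits[i] = weight * (logits[i] - guidanceLogits[i]) + guidanceLogits[i];
    }

    // Numerically stable log-softmax: x_i -> x_i - max - log(sum(exp(x - max)))
    private static void LogSoftmax(Span<float> values)
    {
        var max = float.NegativeInfinity;
        foreach (var v in values)
            max = Math.Max(max, v);

        var sum = 0.0;
        foreach (var v in values)
            sum += Math.Exp(v - max);

        var logSum = (float)Math.Log(sum);
        for (var i = 0; i < values.Length; i++)
            values[i] -= max + logSum;
    }
}

This is why the default weight of 2 in the example produces visibly different output from the unguided fork: the guided sequence is steered towards completions the positive prompt makes likely and the negative prompt does not.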