
BatchedExecutorGuidance.cs 5.3 kB

April 2024 Binary Update (#662)

* Updated binaries, using [this build](https://github.com/SciSharp/LLamaSharp/actions/runs/8654672719/job/23733195669) for llama.cpp commit `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
  - Added all new functions.
  - Moved some functions (e.g. `SafeLlamaModelHandle`-specific functions) into `SafeLlamaModelHandle.cs`.
  - Exposed tokens on `SafeLlamaModelHandle` and `LLamaWeights` through a `Tokens` property. As new special tokens are added in the future they can be added here.
  - Changed all token properties to return nullable tokens, to handle models that do not define some tokens.
  - Fixed `DefaultSamplingPipeline` to handle models with no newline token.
* Moved native methods to more specific locations.
  - Context-specific things have been moved into `SafeLLamaContextHandle.cs` and made private; they are already exposed through C# properties and methods.
  - Checking that the GPU layer count is zero if GPU offload is not supported.
  - Moved methods for creating default structs (`llama_model_quantize_default_params` and `llama_context_default_params`) into the relevant structs.
* Removed the exception thrown when `GpuLayerCount > 0` and GPU offload is not supported.
* Added low-level wrapper methods for the new per-sequence state load/save in `SafeLLamaContextHandle`.
  - Added high-level wrapper methods (save/load with a `State` object or a memory-mapped file) in `LLamaContext`.
  - Moved native methods for per-sequence state load/save into `SafeLLamaContextHandle`.
* Added update and defrag methods for the KV cache in `SafeLLamaContextHandle`.
* Updated the submodule to `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
* Passing the sequence ID when saving a single sequence state.
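The nullable `Tokens` property mentioned in the commit notes can be checked before relying on a particular special token. Below is a minimal sketch of that idea; the model path is a placeholder and the individual property names (`BOS`, `EOS`, `Newline`) are assumptions based on the commit notes, so treat the details as illustrative rather than definitive.

using LLama;
using LLama.Common;

// Sketch: inspect which special tokens a model actually defines.
// Each property is nullable because some models do not include that token.
var parameters = new ModelParams("path/to/model.gguf");   // placeholder path
using var model = LLamaWeights.LoadFromFile(parameters);

Console.WriteLine(model.Tokens.BOS is null ? "No BOS token" : $"BOS: {model.Tokens.BOS}");
Console.WriteLine(model.Tokens.EOS is null ? "No EOS token" : $"EOS: {model.Tokens.EOS}");
Console.WriteLine(model.Tokens.Newline is null ? "No newline token" : $"Newline: {model.Tokens.Newline}");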
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
    private const int n_len = 32;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        using var guided = executor.Create();
        guided.Prompt(positivePrompt);
        using var guidance = executor.Create();
        guidance.Prompt(negativePrompt);

        // Run inference to evaluate prompts
        await AnsiConsole
            .Status()
            .Spinner(Spinner.Known.Line)
            .StartAsync("Evaluating Prompts...", _ => executor.Infer());

        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        using var unguided = guided.Fork();

        // Run inference loop
        var unguidedSampler = new GuidedSampler(null, weight);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance, weight);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
                    // guidance. This serves as a comparison to show the effect of guidance.
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt).
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.Tokens.EOS)
                        break;

                    // Update progress bar
                    reporter.Increment(1);
                }
            });

        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
    }

    private class GuidedSampler(Conversation? guidance, float weight)
        : BaseSamplingPipeline
    {
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
            // This sampler keeps no per-token state, so accepting a token is a no-op
        }

        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            if (guidance == null)
                return;

            // Get the logits generated by the guidance sequence
            var guidanceLogits = guidance.Sample();

            // Use those logits to guide this sequence
            NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);
            return candidates.SampleToken(ctx);
        }
    }
}
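For reference, the call to `llama_sample_apply_guidance` in `GuidedSampler.ProcessLogits` applies classifier-free guidance: the logits of the positive ("guided") sequence are pushed away from the logits of the negative ("guidance") sequence in proportion to the weight. The sketch below is only an illustrative approximation of that mixing step in managed code; the native routine is more involved (it works on normalised log-probabilities), so this is not the actual implementation.

// Sketch of the classifier-free guidance mix (illustrative form, not the native code):
//   guided[i] = guidance[i] + weight * (guided[i] - guidance[i])
// With weight > 1 the output is steered further away from the negative prompt.
static void ApplyGuidanceSketch(Span<float> guidedLogits, ReadOnlySpan<float> guidanceLogits, float weight)
{
    for (var i = 0; i < guidedLogits.Length; i++)
        guidedLogits[i] = guidanceLogits[i] + weight * (guidedLogits[i] - guidanceLogits[i]);
}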