
LLamaStatelessExecutor.cs 6.5 kB

using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;
using Microsoft.Extensions.Logging;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// This executor infers the input as a one-time job. Previous inputs won't impact
    /// the response to the current input.
    /// </summary>
    public class StatelessExecutor
        : ILLamaExecutor
    {
        private readonly LLamaWeights _weights;
        private readonly IContextParams _params;
        private readonly ILogger? _logger;

        /// <summary>
        /// The context used by the executor when running the inference.
        /// </summary>
        public LLamaContext Context { get; private set; }

        /// <summary>
        /// Create a new stateless executor which will use the given model
        /// </summary>
        /// <param name="weights">The model weights to run inference with</param>
        /// <param name="params">The parameters used to create the inference context</param>
        /// <param name="logger">Optional logger</param>
        public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            _weights = weights;
            _params = @params;
            _logger = logger;
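
            // Eagerly create (and immediately dispose) a context so the Context property
            // is always populated; InferAsync creates a fresh context for every call.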
            Context = _weights.CreateContext(_params, logger);
            Context.Dispose();
        }

        /// <inheritdoc />
        public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Ensure the context from last time is disposed (it always should be)
            if (!Context.NativeHandle.IsClosed)
                Context.Dispose();

            // Create an inference context which will be disposed when this method exits
            using var context = _weights.CreateContext(_params, _logger);
            Context = context;

            await foreach (var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
            {
                yield return item;
            }
        }
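
        /// <summary>
        /// Run stateless inference for a single prompt on the given context, yielding the generated text piece by piece.
        /// </summary>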
        public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Sanity check inference params
            inferenceParams ??= new InferenceParams();
            if (inferenceParams.TokensKeep > context.ContextSize)
                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");

            // Keep track of the last N tokens emitted
            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
            var lastTokens = new List<llama_token>(repeat_last_n);
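            // Pre-fill with zeroes so the penalty window always holds repeat_last_n entries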
            for (var i = 0; i < repeat_last_n; i++)
                lastTokens.Add(0);

            // Tokenize the prompt
            var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
            lastTokens.AddRange(tokens);
            var n_past = 1 + tokens.Count;

            // Evaluate the prompt
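            // (Eval starts at past-token offset 1, which is why n_past is 1 + tokens.Count above)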
            await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
                      .ConfigureAwait(false);

            // Begin loop, evaluating one token at a time
            var mu = (float?)null;
            var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
            {
                llama_token id;
                if (inferenceParams.SamplingPipeline is not null)
                {
                    id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
                }
                else
                {
                    // Penalize the generated tokens by various penalties
                    var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                    // Sample a single token
                    id = context.Sample(
                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
                        inferenceParams.MinP
                    );
                }

                // Decode this token into text
                var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
                yield return decoded;

                // Check if the generation should stop
                if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
                    break;
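
                // Remember the sampled token and queue it (alone) as the input for the next evaluation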
                lastTokens.Add(id);
                tokens.Clear();
                tokens.Add(id);

                // When we run out of context, discard half of the tokens after TokensKeep and shift the rest,
                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                if (n_past + tokens.Count >= context.ContextSize)
                {
                    var n_left = n_past - inferenceParams.TokensKeep - 1;
                    var n_discard = n_left / 2;
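
                    // Remove the discarded tokens from the KV cache, then shift the remaining ones down into the freed space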
                    NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
                    NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

                    n_past -= n_discard;
                }

                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
                n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
                                   .ConfigureAwait(false);
            }
        }
    }
}
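
A minimal usage sketch, assuming LLamaSharp's usual ModelParams / LLamaWeights.LoadFromFile entry points and that this fork's InferenceParams supplies a default Tokenizer and GenerationControl (the model path is hypothetical):

using System;
using LLama;
using LLama.Common;

// Hypothetical model path; ModelParams implements IContextParams, so it can be passed straight through
var parameters = new ModelParams("models/llama-2-7b.Q4_K_M.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);

var executor = new StatelessExecutor(weights, parameters);
await foreach (var text in executor.InferAsync("Q: What is the capital of France? A:", new InferenceParams { MaxTokens = 32 }))
{
    Console.Write(text);
}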