
LLamaStatelessExecutor.cs (6.3 kB)

using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// This executor infers the input as a one-time job. Previous inputs won't impact the
    /// response to the current input.
    /// </summary>
    public class StatelessExecutor
        : ILLamaExecutor
    {
        private readonly LLamaWeights _weights;
        private readonly IContextParams _params;
        private readonly ILogger? _logger;

        /// <summary>
        /// The context used by the executor when running the inference.
        /// </summary>
        public LLamaContext Context { get; private set; }

        /// <summary>
        /// Create a new stateless executor which will use the given model
        /// </summary>
        /// <param name="weights"></param>
        /// <param name="params"></param>
        /// <param name="logger"></param>
        public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            _weights = weights;
            _params = @params;
            _logger = logger;

            // Create (and immediately dispose) an initial context so the Context property is never
            // null before the first inference; InferAsync creates a fresh context for each call.
            Context = _weights.CreateContext(_params, logger);
            Context.Dispose();
        }
        /// <inheritdoc />
        public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Ensure the context from last time is disposed (it always should be)
            if (!Context.NativeHandle.IsClosed)
                Context.Dispose();

            // Create an inference context which will be disposed when this method exits
            using var context = _weights.CreateContext(_params, _logger);
            Context = context;

            // Reset the sampling pipeline (if there is one)
            inferenceParams?.SamplingPipeline?.Reset();

            // Sanity check inference params
            inferenceParams ??= new InferenceParams();
            if (inferenceParams.TokensKeep > Context.ContextSize)
                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");

            // Create decoders for the token stream
            var decoder = new StreamingTokenDecoder(Context);
            var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);

            // Keep track of the last N tokens emitted
            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
            var lastTokens = new List<llama_token>(repeat_last_n);
            for (var i = 0; i < repeat_last_n; i++)
                lastTokens.Add(0);

            // Tokenize the prompt
            var tokens = Context.Tokenize(prompt).ToList();
            lastTokens.AddRange(tokens);
            var n_past = 1 + tokens.Count;

            // Evaluate the prompt
            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
                .ConfigureAwait(false);
            // Begin loop, evaluating one token at a time
            var mu = (float?)null;
            var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
            {
                llama_token id;
                if (inferenceParams.SamplingPipeline is not null)
                {
                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
                }
                else
                {
                    // Penalize the generated tokens by various penalties
                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                    // Sample a single token
                    id = Context.Sample(
                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
                        inferenceParams.MinP
                    );
                }

                // Decode this token into text
                decoder.Add(id);
                var decoded = decoder.Read();
                yield return decoded;

                // Check if any of the antiprompts have been generated
                if (antiprocessor.Add(decoded))
                    break;

                lastTokens.Add(id);
                tokens.Clear();
                tokens.Add(id);

                // When the context runs out of space, discard half of the tokens after TokensKeep,
                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                if (n_past + tokens.Count >= Context.ContextSize)
                {
                    var n_left = n_past - inferenceParams.TokensKeep - 1;
                    var n_discard = n_left / 2;

                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

                    n_past -= n_discard;
                }

                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
                    .ConfigureAwait(false);
            }
        }
    }
}
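
Below is a minimal usage sketch (not part of this file) showing how the stateless executor is typically driven: weights are loaded once, an executor is constructed from them, and InferAsync streams text until MaxTokens is reached or an antiprompt is generated. The model path, context size and sampling values are placeholder assumptions, and exact parameter property names may differ slightly between LLamaSharp versions.

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

// ModelParams implements IContextParams, so the same object can be passed to the executor.
// "path/to/model.gguf" is a placeholder; point it at a real GGUF file on disk.
var parameters = new ModelParams("path/to/model.gguf")
{
    ContextSize = 2048
};

using var weights = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(weights, parameters);

var inferenceParams = new InferenceParams
{
    MaxTokens = 64,                            // stop after at most 64 generated tokens
    AntiPrompts = new List<string> { "User:" } // or stop as soon as "User:" is generated
};

// Each InferAsync call is independent: no state is carried over from previous prompts.
await foreach (var text in executor.InferAsync("Question: What is a stateless executor?\nAnswer:", inferenceParams))
    Console.Write(text);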