You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

LLamaStatelessExecutor.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.CompilerServices;
  7. using System.Text;
  8. using System.Threading;
  9. namespace LLama
  10. {
  11. using llama_token = Int32;
  12. /// <summary>
  13. /// This executor infer the input as one-time job. Previous inputs won't impact on the
  14. /// response to current input.
  15. /// </summary>
  16. public class StatelessExecutor
  17. : ILLamaExecutor
  18. {
  19. private readonly LLamaWeights _weights;
  20. private readonly IModelParams _params;
  21. /// <summary>
  22. /// The context used by the executor when running the inference.
  23. /// </summary>
  24. public LLamaContext Context { get; private set; }
  25. /// <summary>
  26. /// Create a new stateless executor which will use the given model
  27. /// </summary>
  28. /// <param name="weights"></param>
  29. /// <param name="params"></param>
  30. public StatelessExecutor(LLamaWeights weights, IModelParams @params)
  31. {
  32. _weights = weights;
  33. _params = @params;
  34. Context = _weights.CreateContext(_params);
  35. Context.Dispose();
  36. }
  37. /// <summary>
  38. /// Create a new stateless executor which will use the model used to create the given context
  39. /// </summary>
  40. /// <param name="context"></param>
  41. [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
  42. public StatelessExecutor(LLamaContext context)
  43. {
  44. _weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding));
  45. _params = context.Params;
  46. Context = _weights.CreateContext(_params);
  47. Context.Dispose();
  48. }
  49. /// <inheritdoc />
  50. public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
  51. {
  52. using var context = _weights.CreateContext(_params);
  53. Context = context;
  54. if (!Context.NativeHandle.IsClosed)
  55. Context.Dispose();
  56. Context = _weights.CreateContext(Context.Params);
  57. if (inferenceParams != null)
  58. {
  59. if (inferenceParams.TokensKeep > Context.ContextSize)
  60. throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
  61. }
  62. cancellationToken.ThrowIfCancellationRequested();
  63. var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
  64. var n_past = 1;
  65. inferenceParams ??= new InferenceParams();
  66. var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
  67. for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
  68. lastTokens.Add(0);
  69. var tokens = Context.Tokenize(text).ToList();
  70. var n_prompt_tokens = tokens.Count;
  71. Context.Eval(tokens, n_past);
  72. lastTokens.AddRange(tokens);
  73. n_past += n_prompt_tokens;
  74. var mu = (float?)null;
  75. var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
  76. for(var i = 0; i < max_tokens; i++)
  77. {
  78. if (cancellationToken.IsCancellationRequested)
  79. break;
  80. var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;
  81. var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
  82. inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
  83. var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
  84. inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
  85. lastTokens.Add(id);
  86. var response = Context.TokenToString(id);
  87. yield return response;
  88. tokens.Clear();
  89. tokens.Add(id);
  90. if (EndsWithAntiprompt(lastTokens, antiprompts))
  91. break;
  92. // when run out of context
  93. // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
  94. if (n_past + tokens.Count > Context.ContextSize)
  95. {
  96. var n_left = n_past - inferenceParams.TokensKeep;
  97. n_past = Math.Max(1, inferenceParams.TokensKeep);
  98. tokens.Clear();
  99. tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
  100. }
  101. n_past = Context.Eval(tokens, n_past);
  102. }
  103. }
  104. /// <summary>
  105. /// Check if the given tokens list ends with any of the antiprompts
  106. /// </summary>
  107. /// <param name="tokens"></param>
  108. /// <param name="antiprompts"></param>
  109. /// <returns></returns>
  110. private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts)
  111. {
  112. if (antiprompts.Count == 0 || tokens.Count == 0)
  113. return false;
  114. var builder = new StringBuilder();
  115. foreach (var token in tokens)
  116. builder.Append(Context.TokenToString(token));
  117. var last_output = builder.ToString();
  118. foreach (var antiprompt in antiprompts)
  119. {
  120. if (last_output.EndsWith(antiprompt))
  121. return true;
  122. }
  123. return false;
  124. }
  125. /// <inheritdoc />
  126. public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
  127. {
  128. foreach (var result in Infer(text, inferenceParams, cancellationToken))
  129. {
  130. yield return result;
  131. }
  132. }
  133. }
  134. }