You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

LLamaStatelessExecutor.cs 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.CompilerServices;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. using LLama.Extensions;
  10. using Microsoft.Extensions.Logging;
  11. namespace LLama
  12. {
  13. using llama_token = Int32;
  14. /// <summary>
  15. /// This executor infer the input as one-time job. Previous inputs won't impact on the
  16. /// response to current input.
  17. /// </summary>
  18. public class StatelessExecutor
  19. : ILLamaExecutor
  20. {
  21. private readonly ILogger? _logger;
  22. private readonly LLamaWeights _weights;
  23. private readonly IModelParams _params;
  24. /// <summary>
  25. /// The context used by the executor when running the inference.
  26. /// </summary>
  27. public LLamaContext Context { get; private set; }
  28. /// <summary>
  29. /// Create a new stateless executor which will use the given model
  30. /// </summary>
  31. /// <param name="weights"></param>
  32. /// <param name="params"></param>
  33. /// <param name="logger"></param>
  34. public StatelessExecutor(LLamaWeights weights, IModelParams @params, ILogger logger = null!)
  35. {
  36. _logger = logger;
  37. _weights = weights;
  38. _params = @params;
  39. Context = _weights.CreateContext(_params);
  40. Context.Dispose();
  41. }
  42. /// <summary>
  43. /// Create a new stateless executor which will use the model used to create the given context
  44. /// </summary>
  45. /// <param name="context"></param>
  46. [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
  47. public StatelessExecutor(LLamaContext context)
  48. {
  49. _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
  50. _params = context.Params;
  51. Context = _weights.CreateContext(_params);
  52. Context.Dispose();
  53. }
  54. /// <inheritdoc />
  55. public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
  56. {
  57. using var context = _weights.CreateContext(_params);
  58. Context = context;
  59. if (!Context.NativeHandle.IsClosed)
  60. Context.Dispose();
  61. Context = _weights.CreateContext(Context.Params);
  62. if (inferenceParams != null)
  63. {
  64. if (inferenceParams.TokensKeep > Context.ContextSize)
  65. throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
  66. }
  67. cancellationToken.ThrowIfCancellationRequested();
  68. var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
  69. inferenceParams ??= new InferenceParams();
  70. var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
  71. for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
  72. lastTokens.Add(0);
  73. var tokens = Context.Tokenize(text).ToList();
  74. await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
  75. .ConfigureAwait(false);
  76. lastTokens.AddRange(tokens);
  77. var n_past = 1 + tokens.Count;
  78. var mu = (float?)null;
  79. var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
  80. for(var i = 0; i < max_tokens; i++)
  81. {
  82. if (cancellationToken.IsCancellationRequested)
  83. break;
  84. var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;
  85. var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
  86. inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
  87. var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
  88. inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
  89. lastTokens.Add(id);
  90. yield return Context.TokenToString(id);
  91. tokens.Clear();
  92. tokens.Add(id);
  93. // Check if any of the antiprompts have been generated
  94. if (lastTokens.TokensEndsWithAnyString(antiprompts, Context))
  95. break;
  96. // when run out of context
  97. // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
  98. if (n_past + tokens.Count > Context.ContextSize)
  99. {
  100. var n_left = n_past - inferenceParams.TokensKeep;
  101. n_past = Math.Max(1, inferenceParams.TokensKeep);
  102. tokens.Clear();
  103. tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
  104. }
  105. // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
  106. n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
  107. .ConfigureAwait(false);
  108. }
  109. }
  110. }
  111. }