
LLamaStatelessExecutor.cs 5.5 kB

using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using LLama.Extensions;
using LLama.Native;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// This executor infers the input as a one-time job. Previous inputs won't impact
    /// the response to the current input.
    /// </summary>
    public class StatelessExecutor
        : ILLamaExecutor
    {
        private readonly LLamaWeights _weights;
        private readonly IContextParams _params;

        /// <summary>
        /// The context used by the executor when running the inference.
        /// </summary>
        public LLamaContext Context { get; private set; }

        /// <summary>
        /// Create a new stateless executor which will use the given model
        /// </summary>
        /// <param name="weights">The model weights to run inference with</param>
        /// <param name="params">The parameters used to create a context for each inference</param>
        public StatelessExecutor(LLamaWeights weights, IContextParams @params)
        {
            _weights = weights;
            _params = @params;

            // Create and immediately dispose a context, so the Context property is never
            // null. A fresh context is created for every call to InferAsync.
            Context = _weights.CreateContext(_params);
            Context.Dispose();
        }

        /// <inheritdoc />
        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Ensure the context from the previous call is disposed (it always should be)
            if (!Context.NativeHandle.IsClosed)
                Context.Dispose();

            // Create a fresh context for this inference, disposed when this method exits
            using var context = _weights.CreateContext(_params);
            Context = context;

            if (inferenceParams != null)
            {
                if (inferenceParams.TokensKeep > Context.ContextSize)
                    throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
            }

            cancellationToken.ThrowIfCancellationRequested();

            var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
            var n_past = 1;
            inferenceParams ??= new InferenceParams();

            // Pre-fill the list of recent tokens consulted by the repetition penalty
            var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
            for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
                lastTokens.Add(0);

            // Tokenize and evaluate the whole prompt in one go
            var tokens = Context.Tokenize(text).ToList();
            var n_prompt_tokens = tokens.Count;
            Context.Eval(tokens, n_past);

            lastTokens.AddRange(tokens);
            n_past += n_prompt_tokens;

            var mu = (float?)null;
            var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for (var i = 0; i < max_tokens; i++)
            {
                if (cancellationToken.IsCancellationRequested)
                    break;

                // Apply repetition penalties to the logits, then sample the next token
                var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;
                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);

                lastTokens.Add(id);

                var response = Context.TokenToString(id);
                yield return response;

                tokens.Clear();
                tokens.Add(id);

                if (EndsWithAntiprompt(lastTokens, antiprompts))
                    break;

                // When the context runs out, discard half of the tokens after TokensKeep and
                // shift the rest down, based on this logic:
                // https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                if (n_past + tokens.Count >= Context.ContextSize)
                {
                    var n_left = n_past - inferenceParams.TokensKeep - 1;
                    var n_discard = n_left / 2;

                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

                    n_past -= n_discard;

                    // Re-evaluate the most recent half of the remaining tokens so generation can continue
                    tokens.Clear();
                    tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
                }

                n_past = Context.Eval(tokens, n_past);
            }
        }

        /// <summary>
        /// Check if the given tokens list ends with any of the antiprompts
        /// </summary>
        /// <param name="tokens">The tokens generated so far</param>
        /// <param name="antiprompts">The strings which should terminate generation</param>
        /// <returns></returns>
        private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts)
        {
            return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
        }
    }
}
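
A minimal usage sketch, assuming a local GGUF model file (the "model.gguf" path, context size, and prompt below are illustrative, not part of this file) and the library's ModelParams type, which implements both the model and context parameter interfaces, so it can be passed straight to the executor:

using System;
using LLama;
using LLama.Common;

// Hypothetical model path; substitute any local GGUF model
var parameters = new ModelParams("model.gguf")
{
    ContextSize = 1024,
};

using var weights = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(weights, parameters);

// Each call is independent: a second InferAsync call would see nothing of this one,
// because a fresh context is created per call.
var inferenceParams = new InferenceParams { MaxTokens = 64, AntiPrompts = new[] { "Question:" } };
await foreach (var token in executor.InferAsync("Question: What is a stateless executor? Answer:", inferenceParams))
    Console.Write(token);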