
LLamaStatelessExecutor.cs 5.3 kB

using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// This executor infers the input as a one-time job: previous inputs don't impact the
    /// response to the current input.
    /// </summary>
    public class StatelessExecutor : ILLamaExecutor
    {
        private LLamaModel _model;
        private byte[] _originalState;

        /// <summary>
        /// The model used by the executor when running the inference.
        /// </summary>
        public LLamaModel Model => _model;

        /// <summary>
        /// Creates a stateless executor and captures the model's initial state.
        /// </summary>
        /// <param name="model">The LLama model.</param>
        public StatelessExecutor(LLamaModel model)
        {
            _model = model;
            // Evaluate a single-space prompt once so that a clean baseline state can be
            // captured; it is restored after every inference to keep calls independent.
            var tokens = model.Tokenize(" ", true).ToArray();
            Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads);
            _originalState = model.GetStateData();
        }
        /// <inheritdoc />
        public IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            cancellationToken.ThrowIfCancellationRequested();
            int n_past = 1;
            inferenceParams ??= new InferenceParams();

            // Pre-fill the repetition-penalty window with padding tokens (0),
            // mirroring llama.cpp's last_n_tokens buffer.
            List<llama_token> lastTokens = Enumerable.Repeat((llama_token)0, Math.Max(0, inferenceParams.RepeatLastTokensCount)).ToList();

            List<llama_token> tokens = _model.Tokenize(text, true).ToList();
            int n_prompt_tokens = tokens.Count;

            Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads);

            lastTokens.AddRange(tokens);
            n_past += n_prompt_tokens;

            int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for (int i = 0; i < max_tokens; i++)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _model.LoadState(_originalState);
                    break;
                }

                var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount;
                var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);

                lastTokens.Add(id);

                string response = Utils.TokenToString(id, _model.NativeHandle, _model.Encoding);
                yield return response;

                tokens.Clear();
                tokens.Add(id);

                if (inferenceParams.AntiPrompts is not null && inferenceParams.AntiPrompts.Any())
                {
                    // Decode the recent tokens and stop if the output ends with an antiprompt.
                    var lastOutputBuilder = new StringBuilder();
                    foreach (var token in lastTokens)
                    {
                        lastOutputBuilder.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, token), _model.Encoding));
                    }
                    string last_output = lastOutputBuilder.ToString();

                    bool should_break = false;
                    foreach (var antiprompt in inferenceParams.AntiPrompts)
                    {
                        if (last_output.EndsWith(antiprompt))
                        {
                            should_break = true;
                            break;
                        }
                    }
                    if (should_break)
                    {
                        break;
                    }
                }

                // When the context window runs out, keep the first TokensKeep tokens and
                // re-insert roughly half of the remaining history before the new token.
                if (n_past + tokens.Count > _model.ContextSize)
                {
                    int n_left = n_past - inferenceParams.TokensKeep;
                    n_past = Math.Max(1, inferenceParams.TokensKeep);
                    // insert n_left/2 tokens at the start of embed from last_n_tokens
                    tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_model.ContextSize - n_left / 2 - tokens.Count));
                }

                n_past = _model.Eval(tokens.ToArray(), n_past);
            }

            _model.LoadState(_originalState);
        }
        /// <inheritdoc />
        public async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            foreach (var result in Infer(text, inferenceParams, cancellationToken))
            {
                yield return result;
            }
        }
    }
}
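
For reference, a minimal usage sketch of the executor follows. This is an illustration under assumptions, not part of the file above: it presumes the LLamaModel, ModelParams, and InferenceParams types exposed by the same version of LLamaSharp, and the model path, prompt, and parameter values are placeholders.

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

public class Program
{
    public static void Main()
    {
        // Assumption: ModelParams/LLamaModel constructors from this LLamaSharp
        // version; "path/to/model.bin" is a placeholder, not a real path.
        using var model = new LLamaModel(new ModelParams("path/to/model.bin", contextSize: 1024));
        var executor = new StatelessExecutor(model);

        var inferenceParams = new InferenceParams
        {
            MaxTokens = 64,                          // stop after 64 generated tokens
            AntiPrompts = new List<string> { "Q:" }  // stop when the model starts a new question
        };

        // Each call is independent: the executor restores its captured baseline
        // state after every inference, so earlier prompts don't leak into later ones.
        foreach (var piece in executor.Infer("Q: What is 2 + 2?\nA:", inferenceParams))
        {
            Console.Write(piece);
        }
    }
}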