
LLamaInteractExecutor.cs 9.8 kB

using LLama.Common;
using LLama.Native;
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Extensions;
using Microsoft.Extensions.Logging;

namespace LLama
{
    /// <summary>
    /// The LLama executor for interactive mode.
    /// </summary>
    public class InteractiveExecutor : StatefulExecutorBase
    {
        private bool _is_prompt_run = true;
        private readonly LLamaToken _llama_token_newline;

        /// <summary>
        /// Create an interactive executor for the given context.
        /// </summary>
        /// <param name="context">The context to run inference with.</param>
        /// <param name="logger">An optional logger.</param>
        public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
            : base(context, logger)
        {
            _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
        }
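
        // Note: _llama_token_newline caches the model's newline token. InferInternal below swaps
        // it in for the end-of-sequence token so an interactive session can continue after EOS.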

        /// <inheritdoc />
        public override ExecutorBaseState GetStateData()
        {
            InteractiveExecutorState state = new()
            {
                ConsumedSessionCount = _n_session_consumed,
                EmbedInps = _embed_inps,
                IsPromptRun = _is_prompt_run,
                ConsumedTokensCount = _consumedTokensCount,
                Embeds = _embeds,
                LastTokens = _last_n_tokens.ToArray(),
                MatchingSessionTokensCount = _n_matching_session_tokens,
                PastTokensCount = _pastTokensCount,
                SessionFilePath = _pathSession,
                SessionTokens = _session_tokens,
                LastTokensCapacity = _last_n_tokens.Capacity,
                MirostatMu = MirostatMu
            };
            return state;
        }

        /// <inheritdoc />
        public override Task LoadState(ExecutorBaseState data)
        {
            if (data is InteractiveExecutorState state)
            {
                _n_session_consumed = state.ConsumedSessionCount;
                _embed_inps = state.EmbedInps;
                _is_prompt_run = state.IsPromptRun;
                _consumedTokensCount = state.ConsumedTokensCount;
                _embeds = state.Embeds;
                _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
                _n_matching_session_tokens = state.MatchingSessionTokensCount;
                _pastTokensCount = state.PastTokensCount;
                _pathSession = state.SessionFilePath;
                _session_tokens = state.SessionTokens;
            }
            else
            {
                throw new ArgumentException("Invalid state data type.");
            }
            return Task.CompletedTask;
        }

        /// <inheritdoc />
        public override async Task SaveState(string filename)
        {
            var state = (InteractiveExecutorState)GetStateData();
            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
            {
                await JsonSerializer.SerializeAsync(fs, state);
            }
        }

        /// <inheritdoc />
        public override async Task LoadState(string filename)
        {
            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
            {
                var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
                await LoadState(state);
            }
        }
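
        // Example (a minimal sketch; "state.json" is an arbitrary path chosen for illustration):
        //
        //     await executor.SaveState("state.json");
        //     // ... later, after recreating the executor with the same model and context ...
        //     await executor.LoadState("state.json");
        //
        // Note that this JSON state only covers the executor's bookkeeping (pending tokens,
        // counters, session file path); it does not serialize the context's KV cache.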

        /// <summary>
        /// Define whether to continue the loop to generate responses.
        /// </summary>
        /// <returns></returns>
        protected override Task<bool> GetLoopCondition(InferStateArgs args)
        {
            // Keep generating while token budget remains and we are not waiting for user input;
            // the first run (the prompt) always proceeds.
            return Task.FromResult((args.RemainedTokens != 0 && !args.WaitForInput) || _is_prompt_run);
        }

        /// <inheritdoc />
        protected override Task PreprocessInputs(string text, InferStateArgs args)
        {
            if (_is_prompt_run)
            {
                // The first input (the prompt) in interactive mode needs special handling:
                // tokenize it with the BOS token prepended.
                _embed_inps = Context.Tokenize(text, true).ToList();
            }
            else
            {
                if (!text.EndsWith("\n"))
                {
                    text += "\n";
                }
                var line_inp = Context.Tokenize(text, false);
                _embed_inps.AddRange(line_inp);
                args.RemainedTokens -= line_inp.Length;
            }
            return Task.CompletedTask;
        }

        /// <summary>
        /// Return whether to break the generation.
        /// </summary>
        /// <param name="inferenceParams"></param>
        /// <param name="args"></param>
        /// <returns></returns>
        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (_embed_inps.Count <= _consumedTokensCount)
            {
                if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                    args.WaitForInput = true;

                if (_pastTokensCount > 0 && args.WaitForInput)
                    return (true, Array.Empty<string>());
            }

            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
            {
                return (true, new[] { " [end of text]\n" });
            }

            if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
            {
                args.RemainedTokens = inferenceParams.MaxTokens;
                args.WaitForInput = true;
            }
            return (false, Array.Empty<string>());
        }
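
        // The checks above yield three outcomes, in order:
        //   1. All queued input has been consumed and the recent tokens end with an antiprompt:
        //      stop and wait for the next user input.
        //   2. The last generated token is end-of-sequence: stop and emit " [end of text]".
        //   3. The token budget is exhausted while MaxTokens != -1: refill the budget and wait
        //      for input instead of stopping outright.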

        /// <inheritdoc />
        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (_embeds.Count > 0)
            {
                _is_prompt_run = false;
                if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                {
                    // The context window is full: shift it, keeping the first TokensKeep tokens.
                    HandleRunOutOfContext(inferenceParams.TokensKeep);
                }

                TryReuseMathingPrefix();
                _pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

                if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
                {
                    _session_tokens.AddRange(_embeds);
                    _n_session_consumed = _session_tokens.Count;
                }
            }

            _embeds.Clear();

            if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
            {
                var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;

                // Optionally save the session on the first sample (for faster prompt loading next time).
                if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
                {
                    args.NeedToSaveSession = false;
                    SaveSessionFile(_pathSession);
                }

                LLamaToken id;
                if (inferenceParams.SamplingPipeline is not null)
                {
                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
                }
                else
                {
                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                    var mu = MirostatMu;
                    id = Context.Sample(
                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
                        inferenceParams.MinP
                    );
                    MirostatMu = mu;
                }

                _last_n_tokens.Enqueue(id);

                if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
                {
                    // Replace end-of-sequence with a newline so the session can continue, and
                    // queue the first antiprompt so the model hands control back to the user.
                    id = _llama_token_newline;
                    if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
                    {
                        var first_antiprompt = Context.Tokenize(args.Antiprompts[0], false);
                        _embed_inps.AddRange(first_antiprompt);
                    }
                }

                _embeds.Add(id);

                args.RemainedTokens--;
                args.ReturnValue = true;
            }
            else
            {
                // Consume pending input tokens, at most BatchSize per call.
                while (_embed_inps.Count > _consumedTokensCount)
                {
                    _embeds.Add(_embed_inps[_consumedTokensCount]);
                    _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
                    _consumedTokensCount++;
                    if (_embeds.Count >= Context.Params.BatchSize)
                    {
                        break;
                    }
                }
            }
        }
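
        // InferInternal prefers InferenceParams.SamplingPipeline when one is set and only falls
        // back to the penalty + mirostat/top-k/top-p path above otherwise. A minimal sketch of
        // opting in (assuming LLamaSharp's DefaultSamplingPipeline from LLama.Sampling):
        //
        //     var inferenceParams = new InferenceParams
        //     {
        //         SamplingPipeline = new DefaultSamplingPipeline { Temperature = 0.6f }
        //     };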

        /// <summary>
        /// The descriptor of the state of the interactive executor.
        /// </summary>
        public class InteractiveExecutorState
            : ExecutorBaseState
        {
            /// <summary>
            /// Whether the executor is running for the first time (running the prompt).
            /// </summary>
            [JsonPropertyName("is_prompt_run")]
            public bool IsPromptRun { get; set; }
        }
    }
}
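
For reference, a typical way to drive this executor (a minimal sketch assuming the usual LLamaSharp entry points: ModelParams, LLamaWeights.LoadFromFile and InferAsync, with a placeholder model path):

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf") { ContextSize = 1024 };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var inferenceParams = new InferenceParams
{
    MaxTokens = 256,                           // -1 would mean no fixed per-turn budget
    AntiPrompts = new List<string> { "User:" } // pause and wait for input at this string
};

// The first call runs the prompt; later calls append user input (see PreprocessInputs above).
await foreach (var piece in executor.InferAsync("User: Hello\nAssistant:", inferenceParams))
    Console.Write(piece);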