You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Conversation.cs 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. using System;
  2. using System.Collections.Generic;
  3. using LLama.Native;
  4. namespace LLama.Batched;
  5. /// <summary>
  6. /// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
  7. /// </summary>
  8. public sealed class Conversation
  9. : IDisposable
  10. {
  11. private ulong _requiredEpoch;
  12. private LLamaPos _end;
  13. private int _batchIndex;
  14. private bool _disposed;
  15. /// <summary>
  16. /// The executor which this conversation belongs to
  17. /// </summary>
  18. public BatchedExecutor Executor { get; }
  19. /// <summary>
  20. /// Unique ID for this conversation
  21. /// </summary>
  22. public LLamaSeqId ConversationId { get; }
  23. /// <summary>
  24. /// Total number of tokens in this conversation, cannot exceed the context length.
  25. /// </summary>
  26. public int TokenCount => _end.Value;
  27. /// <summary>
  28. /// Indicates if this conversation has been disposed, nothing can be done with a disposed conversation
  29. /// </summary>
  30. public bool IsDisposed => _disposed || Executor.IsDisposed;
  31. /// <summary>
  32. /// Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true.
  33. /// </summary>
  34. public bool RequiresInference => _requiredEpoch > Executor.Epoch;
  35. /// <summary>
  36. /// Indicates that this conversation should be sampled.
  37. /// </summary>
  38. public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
  39. #region construction/destruction
  40. internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
  41. {
  42. ConversationId = id;
  43. Executor = batch;
  44. _end = end;
  45. }
  46. /// <summary>
  47. /// Finalizer for Conversation
  48. /// </summary>
  49. ~Conversation()
  50. {
  51. Dispose();
  52. }
  53. /// <summary>
  54. /// End this conversation, freeing all resources used by it
  55. /// </summary>
  56. /// <exception cref="ObjectDisposedException"></exception>
  57. public void Dispose()
  58. {
  59. if (IsDisposed)
  60. return;
  61. _disposed = true;
  62. // Remove this conversation from the KV cache
  63. Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);
  64. // Prevent finalizer from running
  65. GC.SuppressFinalize(this);
  66. }
  67. private void AssertNotDisposed()
  68. {
  69. if (Executor.IsDisposed)
  70. throw new ObjectDisposedException(nameof(BatchedExecutor));
  71. if (IsDisposed)
  72. throw new ObjectDisposedException(nameof(Conversation));
  73. }
  74. #endregion
  75. /// <summary>
  76. /// Create a copy of the current conversation
  77. /// </summary>
  78. /// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
  79. /// <returns></returns>
  80. /// <exception cref="ObjectDisposedException"></exception>
  81. public Conversation Fork()
  82. {
  83. AssertNotDisposed();
  84. if (RequiresInference)
  85. throw new CannotForkWhileRequiresInferenceException();
  86. // Create a new conversation which references the current position in this one
  87. var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
  88. {
  89. _batchIndex = _batchIndex,
  90. _requiredEpoch = _requiredEpoch,
  91. };
  92. // Assign tokens to the new sequence
  93. NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
  94. return c;
  95. }
  96. #region sample
  97. /// <summary>
  98. /// Get the logits from this conversation, ready for sampling
  99. /// </summary>
  100. /// <returns></returns>
  101. /// <exception cref="ObjectDisposedException"></exception>
  102. /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
  103. /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
  104. public ReadOnlySpan<float> Sample()
  105. {
  106. AssertNotDisposed();
  107. if (_requiredEpoch < Executor.Epoch)
  108. throw new CannotSampleRequiresPromptException();
  109. if (_requiredEpoch > Executor.Epoch)
  110. throw new CannotSampleRequiresInferenceException();
  111. return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
  112. }
  113. #endregion
  114. #region prompt
  115. private void AssertCanBePrompted()
  116. {
  117. AssertNotDisposed();
  118. if (RequiresInference)
  119. throw new AlreadyPromptedConversationException();
  120. }
  121. /// <summary>
  122. /// Add tokens to this conversation
  123. /// </summary>
  124. /// <param name="input"></param>
  125. /// <returns></returns>
  126. public void Prompt(string input)
  127. {
  128. AssertCanBePrompted();
  129. Prompt(Executor.Context.Tokenize(input));
  130. }
  131. /// <summary>
  132. /// Add tokens to this conversation
  133. /// </summary>
  134. /// <param name="tokens"></param>
  135. /// <returns></returns>
  136. /// <exception cref="ObjectDisposedException"></exception>
  137. public void Prompt(IReadOnlyList<LLamaToken> tokens)
  138. {
  139. AssertCanBePrompted();
  140. // Add the prompt to the batch
  141. for (var i = 0; i < tokens.Count; i++)
  142. _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
  143. // Mark this conversation as needing inference/sampling
  144. _requiredEpoch = Executor.Epoch + 1;
  145. }
  146. /// <summary>
  147. /// Add a single token to this conversation
  148. /// </summary>
  149. /// <param name="token"></param>
  150. /// <returns></returns>
  151. /// <exception cref="ObjectDisposedException"></exception>
  152. /// <exception cref="InvalidOperationException"></exception>
  153. public void Prompt(LLamaToken token)
  154. {
  155. AssertCanBePrompted();
  156. // Add this token as input
  157. _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
  158. // Mark this conversation as needing inference/sampling
  159. _requiredEpoch = Executor.Epoch + 1;
  160. }
  161. #endregion
  162. #region modify
  163. /// <summary>
  164. /// Directly modify the KV cache of this conversation
  165. /// </summary>
  166. /// <param name="modifier"></param>
  167. /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
  168. public void Modify(ModifyKvCache modifier)
  169. {
  170. AssertNotDisposed();
  171. if (RequiresInference)
  172. throw new CannotModifyWhileRequiresInferenceException();
  173. // do whatever the modification is
  174. _end = modifier.Invoke(_end, new KvAccessor(this));
  175. // Set the epoch down to zero, this ensures that this conversation
  176. // cannot be sampled until it is prompted again.
  177. _requiredEpoch = 0;
  178. }
  179. /// <summary>
  180. /// Provides direct access to the KV cache of a <see cref="Conversation"/>.
  181. /// See <see cref="Modify"/> for how to use this.
  182. /// </summary>
  183. public readonly ref struct KvAccessor
  184. {
  185. private readonly Conversation _conversation;
  186. internal KvAccessor(Conversation conversation)
  187. {
  188. _conversation = conversation;
  189. }
  190. #region remove
  191. /// <summary>
  192. /// Removes all tokens that have positions in [start, end)
  193. /// </summary>
  194. /// <param name="start">Start position (inclusive)</param>
  195. /// <param name="end">End position (exclusive)</param>
  196. public void Remove(LLamaPos start, LLamaPos end)
  197. {
  198. _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
  199. }
  200. /// <summary>
  201. /// Removes all tokens starting from the given position
  202. /// </summary>
  203. /// <param name="start">Start position (inclusive)</param>
  204. /// <param name="count">Number of tokens</param>
  205. public void Remove(LLamaPos start, int count)
  206. {
  207. if (count <= 0)
  208. return;
  209. var end = start.Value + count;
  210. _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
  211. }
  212. #endregion
  213. #region shift
  214. /// <summary>
  215. /// Adds relative position "delta" to all tokens that have positions in [p0, p1).
  216. /// If the KV cache is RoPEd, the KV data is updated
  217. /// accordingly
  218. /// </summary>
  219. /// <param name="start">Start position (inclusive)</param>
  220. /// <param name="end">End position (exclusive)</param>
  221. /// <param name="delta">Amount to add on to each token position</param>
  222. public void Shift(LLamaPos start, LLamaPos end, int delta)
  223. {
  224. _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
  225. }
  226. #endregion
  227. #region divide
  228. /// <summary>
  229. /// Integer division of the positions by factor of `d > 1`.
  230. /// If the KV cache is RoPEd, the KV data is updated accordingly.
  231. /// </summary>
  232. /// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
  233. /// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
  234. /// <param name="divisor">Amount to divide each position by.</param>
  235. public void Divide(LLamaPos start, LLamaPos end, int divisor)
  236. {
  237. if (divisor <= 0)
  238. throw new ArgumentOutOfRangeException(nameof(divisor));
  239. _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
  240. }
  241. #endregion
  242. }
  243. /// <summary>
  244. /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
  245. /// </summary>
  246. /// <param name="end">The current end token of this conversation</param>
  247. /// <param name="kv">An <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
  248. /// <returns>The new end token position</returns>
  249. public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
  250. #endregion
  251. }