
Conversation.cs

using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
/// </summary>
public sealed class Conversation
    : IDisposable
{
    private ulong _requiredEpoch;
    private LLamaPos _end;
    private int _batchIndex;
    private bool _disposed;

    /// <summary>
    /// The executor which this conversation belongs to
    /// </summary>
    public BatchedExecutor Executor { get; }

    /// <summary>
    /// Unique ID for this conversation
    /// </summary>
    public LLamaSeqId ConversationId { get; }

    /// <summary>
    /// Total number of tokens in this conversation; cannot exceed the context length.
    /// </summary>
    public int TokenCount => _end.Value;

    /// <summary>
    /// Indicates if this conversation has been disposed; nothing can be done with a disposed conversation
    /// </summary>
    public bool IsDisposed => _disposed || Executor.IsDisposed;
    /// <summary>
    /// Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true.
    /// </summary>
    public bool RequiresInference => _requiredEpoch > Executor.Epoch;

    /// <summary>
    /// Indicates that this conversation should be sampled.
    /// </summary>
    public bool RequiresSampling => _requiredEpoch == Executor.Epoch;

    #region construction/destruction
    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
    {
        ConversationId = id;
        Executor = batch;
        _end = end;
    }

    /// <summary>
    /// Finalizer for Conversation
    /// </summary>
    ~Conversation()
    {
        Dispose();
    }

    /// <summary>
    /// End this conversation, freeing all resources used by it
    /// </summary>
    /// <exception cref="ObjectDisposedException"></exception>
    public void Dispose()
    {
        if (IsDisposed)
            return;
        _disposed = true;

        // Remove this conversation from the KV cache
        Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);

        // Prevent finalizer from running
        GC.SuppressFinalize(this);
    }

    private void AssertNotDisposed()
    {
        if (Executor.IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(Conversation));
    }
    #endregion

    /// <summary>
    /// Create a copy of the current conversation
    /// </summary>
    /// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Fork()
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new CannotForkWhileRequiresInferenceException();

        // Create a new conversation which references the current position in this one
        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
        {
            _batchIndex = _batchIndex,
            _requiredEpoch = _requiredEpoch,
        };

        // Assign tokens to the new sequence
        NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);

        return c;
    }
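
    // Illustrative usage sketch (not part of this file): forking lets several prompts
    // share the KV cache of their common prefix. Assumes `conversation` was created by
    // a BatchedExecutor (`executor`) and has no pending inference.
    //
    //     var left = conversation.Fork();
    //     var right = conversation.Fork();
    //     left.Prompt(" left branch");
    //     right.Prompt(" right branch");
    //     // Both branches are now queued; a single executor.Infer() advances both.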
    #region sample
    /// <summary>
    /// Get the logits from this conversation, ready for sampling
    /// </summary>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
    /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
    public Span<float> Sample()
    {
        AssertNotDisposed();

        if (_requiredEpoch < Executor.Epoch)
            throw new CannotSampleRequiresPromptException();
        if (_requiredEpoch > Executor.Epoch)
            throw new CannotSampleRequiresInferenceException();

        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
    }
    #endregion
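
    // Illustrative usage sketch (not part of this file): Sample() returns one logit per
    // vocabulary entry. A minimal greedy decode takes the argmax; the cast from the
    // winning index to LLamaToken is an assumption about the token type's conversions.
    //
    //     var logits = conversation.Sample();
    //     var best = 0;
    //     for (var i = 1; i < logits.Length; i++)
    //         if (logits[i] > logits[best])
    //             best = i;
    //     var token = (LLamaToken)best; // hypothetical conversion; a real sampling pipeline would go here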
    #region prompt
    private void AssertCanBePrompted()
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new AlreadyPromptedConversationException();
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="input"></param>
    /// <returns></returns>
    public void Prompt(string input)
    {
        AssertCanBePrompted();

        Prompt(Executor.Context.Tokenize(input));
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="tokens"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public void Prompt(IReadOnlyList<LLamaToken> tokens)
    {
        AssertCanBePrompted();

        // No point doing anything if there is no actual prompt!
        if (tokens.Count == 0)
            return;

        // Add the prompt to the batch
        for (var i = 0; i < tokens.Count; i++)
            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
    }

    /// <summary>
    /// Add a single token to this conversation
    /// </summary>
    /// <param name="token"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="InvalidOperationException"></exception>
    public void Prompt(LLamaToken token)
    {
        AssertCanBePrompted();

        // Add this token as input
        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
    }
    #endregion
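
    // Illustrative usage sketch (not part of this file): the epoch bookkeeping above
    // drives a prompt -> infer -> sample loop. Infer() is the executor method named in
    // the exception docs; PickToken is a hypothetical sampling helper.
    //
    //     conversation.Prompt("The quick brown fox");
    //     for (var i = 0; i < 32; i++)
    //     {
    //         await executor.Infer();           // run the queued batch, advancing the epoch
    //         var next = PickToken(conversation.Sample());
    //         conversation.Prompt(next);        // feed the sampled token back in as input
    //     }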
    #region modify
    /// <summary>
    /// Directly modify the KV cache of this conversation
    /// </summary>
    /// <param name="modifier"></param>
    /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
    public void Modify(ModifyKvCache modifier)
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new CannotModifyWhileRequiresInferenceException();

        // do whatever the modification is
        _end = modifier.Invoke(_end, new KvAccessor(this));

        // Set the epoch down to zero; this ensures that this conversation
        // cannot be sampled until it is prompted again.
        _requiredEpoch = 0;
    }
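
    // Illustrative usage sketch (not part of this file): rewinding a conversation by
    // five tokens with Modify and the KvAccessor defined below. Modify resets the
    // epoch, so the conversation must be prompted again before sampling.
    //
    //     conversation.Modify((end, kv) =>
    //     {
    //         kv.Remove(end.Value - 5, 5);  // drop the last five tokens from the KV cache
    //         return end.Value - 5;         // new end position (int -> LLamaPos conversion assumed)
    //     });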
    /// <summary>
    /// Provides direct access to the KV cache of a <see cref="Conversation"/>.
    /// See <see cref="Modify"/> for how to use this.
    /// </summary>
    public readonly ref struct KvAccessor
    {
        private readonly Conversation _conversation;

        internal KvAccessor(Conversation conversation)
        {
            _conversation = conversation;
        }

        #region remove
        /// <summary>
        /// Removes all tokens that have positions in [start, end)
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        public void Remove(LLamaPos start, LLamaPos end)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }

        /// <summary>
        /// Removes all tokens starting from the given position
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="count">Number of tokens</param>
        public void Remove(LLamaPos start, int count)
        {
            if (count <= 0)
                return;

            var end = start.Value + count;
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }
        #endregion

        #region shift
        /// <summary>
        /// Adds relative position "delta" to all tokens that have positions in [start, end).
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        /// <param name="delta">Amount to add on to each token position</param>
        public void Add(LLamaPos start, LLamaPos end, int delta)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheSequenceAdd(_conversation.ConversationId, start, end, delta);
        }
        #endregion

        #region divide
        /// <summary>
        /// Integer division of the positions by factor of `divisor > 1`.
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        /// </summary>
        /// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
        /// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
        /// <param name="divisor">Amount to divide each position by.</param>
        public void Divide(LLamaPos start, LLamaPos end, int divisor)
        {
            if (divisor <= 0)
                throw new ArgumentOutOfRangeException(nameof(divisor));

            _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
        }
        #endregion
    }

    /// <summary>
    /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
    /// </summary>
    /// <param name="end">The current end token of this conversation</param>
    /// <param name="kv">A <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
    /// <returns>The new end token position</returns>
    public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
    #endregion
}