
BatchedExecutor.cs

using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
public sealed class BatchedExecutor
    : IDisposable
{
    // Sequence id that will be assigned to the next conversation created by this executor.
    private int _nextSequenceId;

    // Shared batch that conversations add their pending tokens to; processed by Infer.
    internal LLamaBatch Batch { get; }

    /// <summary>
    /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
    /// whether they're waiting for inference, or can be sampled.
    /// </summary>
    internal ulong Epoch { get; private set; }

    /// <summary>
    /// The <see cref="LLamaContext"/> this executor is using
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// The <see cref="LLamaWeights"/> this executor is using
    /// </summary>
    public LLamaWeights Model { get; }

    /// <summary>
    /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
    public int BatchedTokenCount => Batch.TokenCount;

    /// <summary>
    /// Check if this executor has been disposed.
    /// </summary>
    public bool IsDisposed { get; private set; }

    /// <summary>
    /// Create a new batched executor
    /// </summary>
    /// <param name="model">The model to use</param>
    /// <param name="contextParams">Parameters to create a new context</param>
    public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
    {
        Model = model;
        Batch = new LLamaBatch();
        Context = model.CreateContext(contextParams);
        Epoch = 1;
    }

    /// <summary>
    /// Start a new <see cref="Conversation"/> with the given prompt
    /// </summary>
    /// <param name="prompt"></param>
    /// <returns></returns>
    [Obsolete("Use BatchedExecutor.Create instead")]
    public Conversation Prompt(string prompt)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Prompt(prompt);
        return conversation;
    }

    /// <summary>
    /// Start a new <see cref="Conversation"/>
    /// </summary>
    /// <returns></returns>
    public Conversation Create()
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        return new Conversation(this, GetNextSequenceId());
    }

    /// <summary>
    /// Load a conversation that was previously saved to a file. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="filepath"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(string filepath)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(filepath);
        return conversation;
    }

    /// <summary>
    /// Load a conversation that was previously saved into memory. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="state"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(Conversation.State state)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(state);
        return conversation;
    }

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
    /// If the result is `NoKvSlot` then there is not enough memory for inference, try disposing some conversation
    /// threads and running inference again.
    /// </summary>
    public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var status = await Context.DecodeAsync(Batch, cancellation);

        // Only clear the batch if the result was ok. Leaving all this state in place means that "Infer" can
        // be called again after a warning (e.g. NoKvSlot).
        if (status == DecodeResult.Ok)
        {
            Epoch++;
            Batch.Clear();
        }

        return status;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        Context.Dispose();
    }

    internal LLamaSeqId GetNextSequenceId()
    {
        return checked((LLamaSeqId)_nextSequenceId++);
    }
}
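
A minimal usage sketch may help show how the pieces above fit together. This is not taken from the repository: the model path and prompts are placeholders, ModelParams (from LLama.Common) is assumed to implement IContextParams so it can be passed to the constructor, and the sampling step is only referenced in a comment because the Conversation sampling API is defined in another file.

using System;
using System.Threading.Tasks;
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;

public static class BatchedExecutorUsageSketch
{
    public static async Task RunAsync()
    {
        // Placeholder path and parameters; ModelParams is assumed to satisfy IContextParams,
        // so the same object can be reused when constructing the BatchedExecutor.
        var parameters = new ModelParams("path/to/model.gguf");
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var executor = new BatchedExecutor(model, parameters);

        // Each conversation is an independent sequence sharing one context and one batch.
        using var left = executor.Create();
        using var right = executor.Create();
        left.Prompt("First question...");
        right.Prompt("Second question...");

        // A single call decodes the pending tokens of every conversation at once.
        var result = await executor.Infer();
        if (result == DecodeResult.NoKvSlot)
        {
            // Not enough KV cache space for all pending tokens: dispose one of the
            // conversations to free its sequence, then call Infer again.
        }

        // After a successful Infer, each prompted conversation can be sampled
        // (see the Conversation class for the sampling API).
        Console.WriteLine($"Decode result: {result}, tokens still batched: {executor.BatchedTokenCount}");
    }
}

Because all conversations share one LLamaContext, a single Infer call amortises decoding across every pending sequence, which is the point of batching them in the first place.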