
BatchedExecutor.cs

using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
public sealed class BatchedExecutor
    : IDisposable
{
    private int _nextSequenceId;

    internal LLamaBatch Batch { get; }

    /// <summary>
    /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
    /// whether they're waiting for inference, or can be sampled.
    /// </summary>
    internal ulong Epoch { get; private set; }

    /// <summary>
    /// The <see cref="LLamaContext"/> this executor is using
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// The <see cref="LLamaWeights"/> this executor is using
    /// </summary>
    public LLamaWeights Model { get; }

    /// <summary>
    /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
    public int BatchedTokenCount => Batch.TokenCount;

    /// <summary>
    /// Check if this executor has been disposed.
    /// </summary>
    public bool IsDisposed { get; private set; }

    /// <summary>
    /// Create a new batched executor
    /// </summary>
    /// <param name="model">The model to use</param>
    /// <param name="contextParams">Parameters to create a new context</param>
    public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
    {
        Model = model;
        Batch = new LLamaBatch();
        Context = model.CreateContext(contextParams);
        Epoch = 1;
    }

    /// <summary>
    /// Finalizer for BatchedExecutor
    /// </summary>
    ~BatchedExecutor()
    {
        Dispose();
    }

    /// <summary>
    /// Start a new <see cref="Conversation"/> with the given prompt
    /// </summary>
    /// <param name="prompt"></param>
    /// <returns></returns>
    public Conversation Prompt(string prompt)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = new Conversation(this, GetNextSequenceId(), 0);
        conversation.Prompt(prompt);

        return conversation;
    }

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
    /// If the result is `NoKvSlot` then there is not enough memory for inference, try disposing some conversation
    /// threads and running inference again.
    /// </summary>
    public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var status = await Context.DecodeAsync(Batch, cancellation);

        // Only clear the batch if the result was ok. Leaving all this state in place means that "Infer" can
        // be called again after a warning (e.g. NoKvSlot).
        if (status == DecodeResult.Ok)
        {
            Epoch++;
            Batch.Clear();
        }

        return status;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        GC.SuppressFinalize(this);

        Context.Dispose();
    }

    internal LLamaSeqId GetNextSequenceId()
    {
        return checked((LLamaSeqId)_nextSequenceId++);
    }
}
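
To make the flow of this class concrete, below is a minimal usage sketch. It relies only on members visible in this file (the constructor, Prompt, Infer, BatchedTokenCount, Epoch and Dispose); the ModelParams and LLamaWeights.LoadFromFile calls used to obtain the weights, and the assumption that Conversation is disposable, come from the wider LLamaSharp API and are assumptions rather than something defined here. Sampling output from a Conversation lives outside this class and is only mentioned in comments.

// Minimal usage sketch (not part of BatchedExecutor.cs).
// Assumptions: ModelParams / LLamaWeights.LoadFromFile exist as in the wider LLamaSharp API,
// and Conversation is disposable (as the Infer() remarks about disposing conversations suggest).
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;

public static class BatchedExecutorSketch
{
    public static async Task RunAsync(string modelPath)
    {
        // ModelParams is assumed to provide both the model and context parameters.
        var parameters = new ModelParams(modelPath);
        using var weights = LLamaWeights.LoadFromFile(parameters);
        using var executor = new BatchedExecutor(weights, parameters);

        // Each Prompt() call starts an independent conversation; its prompt tokens are queued
        // into the shared LLamaBatch rather than being decoded immediately.
        var first = executor.Prompt("Question one?");
        var second = executor.Prompt("Question two?");
        Console.WriteLine($"Tokens queued for the next Infer(): {executor.BatchedTokenCount}");

        // A single Infer() call decodes the pending tokens for every conversation at once.
        var result = await executor.Infer();
        if (result == DecodeResult.NoKvSlot)
        {
            // Not enough KV cache space: free a conversation and run inference again,
            // as the Infer() documentation recommends.
            first.Dispose();
            result = await executor.Infer();
        }

        // After a successful Infer() the executor's Epoch has advanced and each remaining
        // Conversation can be sampled; sampling is part of Conversation, not this class.
    }
}

The sketch mirrors the intended call pattern: queue work via Prompt(), decode everything in one Infer() pass, and treat NoKvSlot as a recoverable condition because the batch is deliberately left intact when decoding does not return Ok.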