
BatchedExecutor.cs

using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
public sealed class BatchedExecutor
    : IDisposable
{
    // Sequence id to assign to the next conversation created by this executor.
    private int _nextSequenceId;

    // Shared batch which all conversations add their pending tokens to, decoded by Infer.
    internal LLamaBatch Batch { get; }

    /// <summary>
    /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
    /// whether they're waiting for inference, or can be sampled.
    /// </summary>
    internal ulong Epoch { get; private set; }

    /// <summary>
    /// The <see cref="LLamaContext"/> this executor is using
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// The <see cref="LLamaWeights"/> this executor is using
    /// </summary>
    public LLamaWeights Model { get; }

    /// <summary>
    /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
    public int BatchedTokenCount => Batch.TokenCount;

    /// <summary>
    /// Check if this executor has been disposed.
    /// </summary>
    public bool IsDisposed { get; private set; }

    /// <summary>
    /// Create a new batched executor
    /// </summary>
    /// <param name="model">The model to use</param>
    /// <param name="contextParams">Parameters to create a new context</param>
    public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
    {
        Model = model;
        Batch = new LLamaBatch();
        Context = model.CreateContext(contextParams);
        Epoch = 1;
    }

    ~BatchedExecutor()
    {
        Dispose();
    }

    /// <summary>
    /// Start a new <see cref="Conversation"/> with the given prompt
    /// </summary>
    /// <param name="prompt">The prompt to begin the conversation with</param>
    /// <returns>The new conversation</returns>
    public Conversation Prompt(string prompt)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = new Conversation(this, GetNextSequenceId(), 0);
        conversation.Prompt(prompt);

        return conversation;
    }

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
    /// If the result is `NoKvSlot` then there is not enough memory for inference, try disposing some conversation
    /// threads and running inference again.
    /// </summary>
    public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var status = await Context.DecodeAsync(Batch, cancellation);

        // Only clear the batch if the result was ok. Leaving all this state in place means that "Infer" can
        // be called again after a warning (e.g. NoKvSlot).
        if (status == DecodeResult.Ok)
        {
            Epoch++;
            Batch.Clear();
        }

        return status;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        GC.SuppressFinalize(this);

        Context.Dispose();
    }

    internal LLamaSeqId GetNextSequenceId()
    {
        // Checked cast throws if the sequence id counter ever overflows.
        return checked((LLamaSeqId)_nextSequenceId++);
    }
}
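
For context, here is a minimal sketch of how this executor might be driven. It assumes the ModelParams and LLamaWeights.LoadFromFile APIs from the same library, and it assumes the companion Conversation type (referenced above but defined in its own file) is disposable, as the Infer documentation suggests:

using System;
using System.Threading.Tasks;
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;

public static class BatchedExecutorExample
{
    public static async Task RunAsync(string modelPath)
    {
        // ModelParams implements IContextParams, so it can be passed straight to the executor.
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var executor = new BatchedExecutor(model, parameters);

        // Each Prompt call creates a conversation with its own sequence id; the prompt
        // tokens are queued in the shared batch but not yet decoded.
        var first = executor.Prompt("The capital of France is");
        var second = executor.Prompt("The tallest mountain on Earth is");

        // One Infer call decodes the pending tokens of both conversations together.
        var status = await executor.Infer();

        // Per the Infer docs: NoKvSlot means there is not enough memory. The batch is
        // not cleared on failure, so free a conversation (disposal assumed) and retry.
        if (status == DecodeResult.NoKvSlot)
        {
            second.Dispose();
            status = await executor.Infer();
        }

        if (status != DecodeResult.Ok)
            throw new InvalidOperationException($"Inference failed: {status}");

        // After a successful Infer the Epoch advances and each conversation can be
        // sampled or continued through the Conversation API (not shown in this file).
    }
}

Because every conversation shares one context and one batch, a single Infer call services all of them at once; reading tokens back out goes through the Conversation API in its own file.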