
BatchedExecutor.cs

using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
public sealed class BatchedExecutor
    : IDisposable
{
    private int _nextSequenceId;

    internal LLamaBatch Batch { get; }

    /// <summary>
    /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
    /// whether they're waiting for inference, or can be sampled.
    /// </summary>
    internal ulong Epoch { get; private set; }

    /// <summary>
    /// The <see cref="LLamaContext"/> this executor is using
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// The <see cref="LLamaWeights"/> this executor is using
    /// </summary>
    public LLamaWeights Model { get; }

    /// <summary>
    /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
    public int BatchedTokenCount => Batch.TokenCount;

    /// <summary>
    /// Check if this executor has been disposed.
    /// </summary>
    public bool IsDisposed { get; private set; }

    /// <summary>
    /// Create a new batched executor
    /// </summary>
    /// <param name="model">The model to use</param>
    /// <param name="contextParams">Parameters to create a new context</param>
    public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
    {
        Model = model;
        Batch = new LLamaBatch();
        Context = model.CreateContext(contextParams);
        Epoch = 1;
    }

    /// <summary>
    /// Start a new <see cref="Conversation"/>
    /// </summary>
    /// <returns></returns>
    public Conversation Create()
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        return new Conversation(this, GetNextSequenceId());
    }

    /// <summary>
    /// Load a conversation that was previously saved to a file. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="filepath"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(string filepath)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(filepath);
        return conversation;
    }

    /// <summary>
    /// Load a conversation that was previously saved into memory. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="state"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(Conversation.State state)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(state);
        return conversation;
    }
    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
    /// If the result is `NoKvSlot` then there is not enough memory for inference; try disposing some
    /// conversations and running inference again.
    /// </summary>
    public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var status = await Context.DecodeAsync(Batch, cancellation);

        // Only clear the batch if the result was Ok. Leaving all this state in place means that
        // "Infer" can be called again after a recoverable failure (e.g. NoKvSlot).
        if (status == DecodeResult.Ok)
        {
            Epoch++;
            Batch.Clear();
        }

        return status;
    }
    /// <inheritdoc />
    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        Context.Dispose();
    }

    internal LLamaSeqId GetNextSequenceId()
    {
        return checked((LLamaSeqId)_nextSequenceId++);
    }
}
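
To show how the pieces above fit together, here is a minimal usage sketch of the Create/Infer lifecycle. It is illustrative only: `Conversation.Prompt` and `Conversation.Dispose` are assumed from the doc comments above (the `Conversation` class is defined elsewhere, not in this file), and the model path is a placeholder.

// A minimal usage sketch (not part of this file). Assumes `Conversation`
// exposes a `Prompt` method and is disposable, as implied by the doc
// comments above. "path/to/model.gguf" is a placeholder.
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;

var parameters = new ModelParams("path/to/model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Each Create() call takes a fresh sequence id, so both conversations
// share one context and batch but are decoded independently.
using var a = executor.Create();
using var b = executor.Create();
a.Prompt("Question one...");   // assumed API: queues tokens into the shared batch
b.Prompt("Question two...");

// A single Infer() call decodes the pending tokens of every conversation.
var status = await executor.Infer();

// NoKvSlot is recoverable: the batch is not cleared, so freeing a
// conversation and calling Infer() again may succeed.
if (status == DecodeResult.NoKvSlot)
{
    b.Dispose();
    status = await executor.Infer();
}

Note that only a successful Infer() advances Epoch and clears the batch, which is how conversations distinguish "waiting for inference" from "ready to sample".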