You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

LLamaBatchSafeHandle.cs 4.4 kB

2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. using System;
  2. namespace LLama.Native;
  3. /// <summary>
  4. /// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
  5. /// </summary>
  6. public sealed class LLamaBatchSafeHandle
  7. : SafeLLamaHandleBase
  8. {
  9. private readonly int _embd;
  10. /// <summary>
  11. /// Get the native llama_batch struct
  12. /// </summary>
  13. public LLamaNativeBatch NativeBatch;
  14. /// <summary>
  15. /// the token ids of the input (used when embd is NULL)
  16. /// </summary>
  17. public Span<LLamaToken> Token
  18. {
  19. get
  20. {
  21. unsafe
  22. {
  23. if (_embd != 0)
  24. return new Span<LLamaToken>(null, 0);
  25. else
  26. return new Span<LLamaToken>(NativeBatch.token, NativeBatch.n_tokens);
  27. }
  28. }
  29. }
  30. /// <summary>
  31. /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
  32. /// </summary>
  33. public Span<LLamaToken> Embed
  34. {
  35. get
  36. {
  37. unsafe
  38. {
  39. // If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
  40. // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
  41. if (_embd != 0)
  42. return new Span<LLamaToken>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
  43. else
  44. return new Span<LLamaToken>(null, 0);
  45. }
  46. }
  47. }
  48. /// <summary>
  49. /// the positions of the respective token in the sequence
  50. /// </summary>
  51. public Span<LLamaPos> Pos
  52. {
  53. get
  54. {
  55. unsafe
  56. {
  57. return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
  58. }
  59. }
  60. }
  61. /// <summary>
  62. /// the sequence to which the respective token belongs
  63. /// </summary>
  64. public Span<LLamaSeqId> Sequence_ID
  65. {
  66. get
  67. {
  68. unsafe
  69. {
  70. return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
  71. }
  72. }
  73. }
  74. /// <summary>
  75. /// if zero, the logits for the respective token will not be output
  76. /// </summary>
  77. public Span<byte> Logits
  78. {
  79. get
  80. {
  81. unsafe
  82. {
  83. return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
  84. }
  85. }
  86. }
  87. /// <summary>
  88. /// Create a safe handle owning a `LLamaNativeBatch`
  89. /// </summary>
  90. /// <param name="batch"></param>
  91. /// <param name="embd"></param>
  92. public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
  93. : base((nint)1)
  94. {
  95. _embd = embd;
  96. NativeBatch = batch;
  97. }
  98. /// <summary>
  99. /// Call `llama_batch_init` and create a new batch
  100. /// </summary>
  101. /// <param name="n_tokens"></param>
  102. /// <param name="embd"></param>
  103. /// <param name="n_seq_max"></param>
  104. /// <returns></returns>
  105. public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
  106. {
  107. var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
  108. return new LLamaBatchSafeHandle(batch, embd);
  109. }
  110. /// <inheritdoc />
  111. protected override bool ReleaseHandle()
  112. {
  113. NativeApi.llama_batch_free(NativeBatch);
  114. NativeBatch = default;
  115. SetHandle(IntPtr.Zero);
  116. return true;
  117. }
  118. /// <summary>
  119. /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
  120. /// </summary>
  121. public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
  122. {
  123. unsafe
  124. {
  125. NativeBatch.token[NativeBatch.n_tokens] = token;
  126. NativeBatch.pos[NativeBatch.n_tokens] = pos;
  127. NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
  128. for (var i = 0; i < sequences.Length; i++)
  129. NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];
  130. NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);
  131. NativeBatch.n_tokens++;
  132. }
  133. }
  134. /// <summary>
  135. /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
  136. /// </summary>
  137. public void LLamaBatchClear()
  138. {
  139. NativeBatch.n_tokens = 0;
  140. }
  141. }