You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

LLamaBatchSafeHandle.cs 3.4 kB

2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. using System;
  2. namespace LLama.Native;
  3. using llama_token = Int32;
  4. /// <summary>
  5. /// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
  6. /// </summary>
  7. public sealed class LLamaBatchSafeHandle
  8. : SafeLLamaHandleBase
  9. {
  10. private readonly int _embd;
  11. /// <summary>
  12. /// Get the native llama_batch struct
  13. /// </summary>
  14. public LLamaNativeBatch NativeBatch;
  15. /// <summary>
  16. /// the token ids of the input (used when embd is NULL)
  17. /// </summary>
  18. public Span<llama_token> Token
  19. {
  20. get
  21. {
  22. unsafe
  23. {
  24. if (_embd != 0)
  25. return new Span<int>(null, 0);
  26. else
  27. return new Span<int>(NativeBatch.token, NativeBatch.n_tokens);
  28. }
  29. }
  30. }
  31. /// <summary>
  32. /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
  33. /// </summary>
  34. public Span<llama_token> Embed
  35. {
  36. get
  37. {
  38. unsafe
  39. {
  40. // If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
  41. // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
  42. if (_embd != 0)
  43. return new Span<llama_token>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
  44. else
  45. return new Span<llama_token>(null, 0);
  46. }
  47. }
  48. }
  49. /// <summary>
  50. /// the positions of the respective token in the sequence
  51. /// </summary>
  52. public Span<LLamaPos> Pos
  53. {
  54. get
  55. {
  56. unsafe
  57. {
  58. return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
  59. }
  60. }
  61. }
  62. /// <summary>
  63. /// the sequence to which the respective token belongs
  64. /// </summary>
  65. public Span<LLamaSeqId> Sequence_ID
  66. {
  67. get
  68. {
  69. unsafe
  70. {
  71. return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
  72. }
  73. }
  74. }
  75. /// <summary>
  76. /// if zero, the logits for the respective token will not be output
  77. /// </summary>
  78. public Span<byte> Logits
  79. {
  80. get
  81. {
  82. unsafe
  83. {
  84. return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
  85. }
  86. }
  87. }
  88. /// <summary>
  89. /// Create a safe handle owning a `LLamaNativeBatch`
  90. /// </summary>
  91. /// <param name="batch"></param>
  92. /// <param name="embd"></param>
  93. public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
  94. : base((nint)1)
  95. {
  96. _embd = embd;
  97. NativeBatch = batch;
  98. }
  99. /// <summary>
  100. /// Call `llama_batch_init` and create a new batch
  101. /// </summary>
  102. /// <param name="n_tokens"></param>
  103. /// <param name="embd"></param>
  104. /// <param name="n_seq_max"></param>
  105. /// <returns></returns>
  106. public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
  107. {
  108. var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
  109. return new LLamaBatchSafeHandle(batch, embd);
  110. }
  111. /// <inheritdoc />
  112. protected override bool ReleaseHandle()
  113. {
  114. NativeApi.llama_batch_free(NativeBatch);
  115. NativeBatch = default;
  116. SetHandle(IntPtr.Zero);
  117. return true;
  118. }
  119. }