You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaBatchSafeHandle.cs 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. using System;
  2. namespace LLama.Native;
  3. using llama_token = Int32;
  /// <summary>
  /// Safe handle that owns a native <c>llama_batch</c> created by <c>NativeApi.llama_batch_init</c>
  /// and handed to <c>NativeApi.llama_batch_free</c> when the handle is released.
  /// </summary>
  /// <remarks>
  /// The span properties point directly into the native batch buffers. A span obtained before the
  /// handle is released must not be used afterwards. Once released, <see cref="Batch"/> is reset to
  /// <c>default</c> (null pointers, zero count), so spans obtained after that are empty.
  /// </remarks>
  public sealed class LLamaBatchSafeHandle
      : SafeLLamaHandleBase
  {
      /// <summary>
      /// Embedding width passed to the constructor. Zero means the batch stores token ids
      /// (<see cref="Token"/>); non-zero means it stores embeddings (<see cref="Embeddings"/>).
      /// </summary>
      private readonly int _embd;

      /// <summary>
      /// The native batch struct returned by <c>llama_batch_init</c>; <c>default</c> after release.
      /// </summary>
      public LLamaNativeBatch Batch { get; private set; }

      /// <summary>
      /// Number of per-token entries exposed by the spans. Clamped at zero because the
      /// <see cref="Span{T}"/> pointer constructor throws on a negative length.
      /// </summary>
      private int TokenCount => Math.Max(Batch.n_tokens, 0);

      /// <summary>
      /// Number of embedding values exposed: <see cref="TokenCount"/> times the embedding width.
      /// Computed in a checked context so an overflow fails loudly instead of yielding a wrong span length.
      /// </summary>
      private int EmbeddingCount => checked(TokenCount * _embd);

      /// <summary>
      /// the token ids of the input (used when embd is NULL)
      /// </summary>
      /// <remarks>Empty when the batch was created with a non-zero embedding width.</remarks>
      public Span<llama_token> Token
      {
          get
          {
              if (_embd != 0)
              {
                  return Span<llama_token>.Empty;
              }

              unsafe
              {
                  return new Span<llama_token>(Batch.token, TokenCount);
              }
          }
      }

      /// <summary>
      /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
      /// </summary>
      /// <remarks>
      /// The native buffer holds <see cref="float"/> values (<c>n_tokens * embd</c> of them), but this
      /// span is typed as <c>llama_token</c> (<see cref="int"/>), so each element is the raw bit pattern
      /// of a float. Kept for backward compatibility; prefer <see cref="Embeddings"/>, which views the
      /// same memory as floats. Empty when the batch was created with an embedding width of zero.
      /// </remarks>
      public Span<llama_token> Embed
      {
          get
          {
              if (_embd == 0)
              {
                  return Span<llama_token>.Empty;
              }

              unsafe
              {
                  return new Span<llama_token>(Batch.embd, EmbeddingCount);
              }
          }
      }

      /// <summary>
      /// Token embeddings as floats: <c>n_tokens * embd</c> values over the same native buffer as
      /// <see cref="Embed"/>. Empty when the batch was created with an embedding width of zero.
      /// </summary>
      public Span<float> Embeddings
      {
          get
          {
              if (_embd == 0)
              {
                  return Span<float>.Empty;
              }

              unsafe
              {
                  return new Span<float>(Batch.embd, EmbeddingCount);
              }
          }
      }

      /// <summary>
      /// the positions of the respective token in the sequence
      /// </summary>
      public Span<LLamaPos> Pos
      {
          get
          {
              unsafe
              {
                  return new Span<LLamaPos>(Batch.pos, TokenCount);
              }
          }
      }

      /// <summary>
      /// the sequence to which the respective token belongs
      /// </summary>
      public Span<LLamaSeqId> Sequence_ID
      {
          get
          {
              unsafe
              {
                  return new Span<LLamaSeqId>(Batch.seq_id, TokenCount);
              }
          }
      }

      /// <summary>
      /// if zero, the logits for the respective token will not be output
      /// </summary>
      public Span<byte> Logits
      {
          get
          {
              unsafe
              {
                  return new Span<byte>(Batch.logits, TokenCount);
              }
          }
      }

      /// <summary>
      /// Allocates a native batch via <c>llama_batch_init</c>.
      /// </summary>
      /// <param name="n_tokens">Number of token slots to allocate.</param>
      /// <param name="embd">
      /// Embedding width. Zero allocates storage for <paramref name="n_tokens"/> token ids; non-zero
      /// allocates <c>n_tokens * embd</c> floats instead.
      /// </param>
      public LLamaBatchSafeHandle(int n_tokens, int embd)
          : base((nint)1)
      {
          // The handle value is only a non-zero placeholder (presumably so the base class does not
          // consider it invalid); the actual native resource is the Batch struct below.
          _embd = embd;
          Batch = NativeApi.llama_batch_init(n_tokens, embd);
      }

      /// <inheritdoc />
      /// <remarks>
      /// Passes the batch to <c>llama_batch_free</c>, then clears <see cref="Batch"/> and the handle value
      /// so later span accesses return empty spans rather than pointing at released memory.
      /// </remarks>
      protected override bool ReleaseHandle()
      {
          NativeApi.llama_batch_free(Batch);
          Batch = default;
          SetHandle(IntPtr.Zero);
          return true;
      }
  }