You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

BaseSamplingPipeline.cs 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using LLama.Native;
  5. namespace LLama.Sampling;
  /// <summary>
  /// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
  /// </summary>
  /// <remarks>
  /// <see cref="Sample"/> runs the stages in this order:
  /// <list type="number">
  /// <item>Save the logit values of the tokens returned by <see cref="GetProtectedTokens"/>.</item>
  /// <item>Call <see cref="ProcessLogits"/> on the raw logits.</item>
  /// <item>Restore the saved logit values into the raw logits.</item>
  /// <item>Build a <see cref="LLamaTokenDataArray"/> from the logits and call <see cref="ProcessTokenDataArray"/>.</item>
  /// <item>Return the token selected by <see cref="ChooseToken"/>.</item>
  /// </list>
  /// The saved values are only available to <c>RestoreProtectedTokens</c> while <see cref="Sample"/> is running.
  /// They are not restored automatically after <see cref="ProcessTokenDataArray"/>. Derived classes that
  /// need that can call <c>RestoreProtectedTokens(candidates)</c> themselves.
  /// </remarks>
  public abstract class BaseSamplingPipeline
      : ISamplingPipeline
  {
      // Number of valid entries at the start of _savedLogits (a pooled array may be longer than requested).
      private int _savedLogitsCount;

      // Pooled buffer of (token, logit) pairs saved before ProcessLogits; null when Sample is not running.
      private (LLamaToken index, float logit)[]? _savedLogits;

      /// <inheritdoc/>
      public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
      {
          var protectedLogits = GetProtectedTokens(ctx);
          var count = protectedLogits.Count;

          // Hold the rented buffer in a local so that exactly this array is handed back to the pool,
          // even if the fields are cleared or replaced while the pipeline stages run.
          var rented = ArrayPool<(LLamaToken, float)>.Shared.Rent(count);
          _savedLogits = rented;
          _savedLogitsCount = count;
          try
          {
              // Save the values of protected logits
              for (var i = 0; i < count; i++)
              {
                  var index = protectedLogits[i];
                  rented[i] = (index, logits[(int)index]);
              }

              // Process raw logits
              ProcessLogits(ctx, logits, lastTokens);

              // Automatically restore saved logit values after processing
              RestoreProtectedTokens(logits);

              // Convert logits into token candidates
              var candidates = LLamaTokenDataArray.Create(logits);

              // Process token data array. The returned token is intentionally not used:
              // the final choice is delegated to ChooseToken below.
              ProcessTokenDataArray(ctx, candidates, lastTokens);

              // Choose the final value
              return ChooseToken(ctx, candidates);
          }
          finally
          {
              ArrayPool<(LLamaToken, float)>.Shared.Return(rented);
              _savedLogits = null;
              _savedLogitsCount = 0;
          }
      }

      #region protected tokens
      /// <summary>
      /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <returns>Tokens whose logit values are saved before <see cref="ProcessLogits"/> and restored after it</returns>
      protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx);

      /// <summary>
      /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits.
      /// Does nothing when called outside of <see cref="Sample"/>.
      /// </summary>
      /// <param name="logits">Logits to write the saved values back into (indexed by token id)</param>
      protected void RestoreProtectedTokens(Span<float> logits)
      {
          if (_savedLogits == null)
          {
              return;
          }

          // The array may be bigger than necessary, get a span of the valid bit
          var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

          // Restore the values of protected logits
          for (var i = 0; i < saved.Length; i++)
          {
              logits[(int)saved[i].index] = saved[i].logit;
          }
      }

      /// <summary>
      /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits.
      /// Does nothing when called outside of <see cref="Sample"/> or when no tokens are protected.
      /// </summary>
      /// <param name="candidates">Candidates whose logits are overwritten with the saved values</param>
      protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
      {
          // Test the valid count, not the array length: a pooled array may be longer than the saved data.
          if (_savedLogits == null || _savedLogitsCount == 0)
          {
              return;
          }

          candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
      }
      #endregion

      /// <summary>
      /// Process the raw logit values
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="logits">The logits produced by the model</param>
      /// <param name="lastTokens">A list of tokens recently returned by the model</param>
      protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

      /// <summary>
      /// Process the <see cref="LLamaTokenDataArray"/> built from the logits. <see cref="Sample"/> calls this after
      /// <see cref="ProcessLogits"/> and before <see cref="ChooseToken"/>.
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
      /// <param name="lastTokens">A list of tokens recently returned by the model</param>
      /// <returns>
      /// A token chosen by this stage. Note: <see cref="Sample"/> ignores this value.
      /// The token it returns comes from <see cref="ChooseToken"/>.
      /// </returns>
      protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);

      /// <summary>
      /// Choose the final token from the candidates
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="candidates">The candidates after <see cref="ProcessTokenDataArray"/> has run</param>
      /// <returns>The token returned by <see cref="Sample"/></returns>
      protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

      /// <inheritdoc/>
      public virtual void Reset()
      {
      }

      /// <inheritdoc/>
      /// <remarks>This base implementation only calls <see cref="GC.SuppressFinalize"/>.</remarks>
      public virtual void Dispose()
      {
          GC.SuppressFinalize(this);
      }
  }