You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ISamplingPipeline.cs 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Runtime.InteropServices;
  5. using LLama.Native;
  6. using LLama.Sampling.Logits;
  7. using LLama.Sampling.Selection;
  8. using LLama.Sampling.Tokens;
  9. namespace LLama.Sampling;
  /// <summary>
  /// Turns a span of logits into one sampled token
  /// </summary>
  public interface ISamplingPipeline : IDisposable
  {
      /// <summary>
      /// Pick a single token based on the supplied logits
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="logits">The logits produced by the model (passed as a writable span)</param>
      /// <param name="lastTokens">A span of tokens recently returned by the model</param>
      /// <returns>The sampled token</returns>
      int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

      /// <summary>
      /// Clear all internal state held by the sampling pipeline
      /// </summary>
      void Reset();
  }
  /// <summary>
  /// Extension methods for <see cref="ISamplingPipeline"/>
  /// </summary>
  public static class ISamplingPipelineExtensions
  {
      /// <summary>
      /// Sample a single token from the given logits, with the recent tokens supplied as a list
      /// </summary>
      /// <param name="pipeline">The pipeline that performs the sampling</param>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="logits">The logits produced by the model</param>
      /// <param name="lastTokens">A list of tokens recently returned by the model</param>
      /// <returns>The token returned by <paramref name="pipeline"/></returns>
      public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
      {
  #if NET5_0_OR_GREATER
          // View the list contents as a span directly, no copy required
          var span = CollectionsMarshal.AsSpan(lastTokens);
          return pipeline.Sample(ctx, logits, span);
  #else
          // Fallback for older targets: copy the list into a pooled buffer
          var count = lastTokens.Count;
          var copy = ArrayPool<int>.Shared.Rent(count);
          try
          {
              lastTokens.CopyTo(copy);

              // Rent may return an array longer than requested, so slice to the real
              // token count; otherwise stale pool contents would be treated as tokens
              return pipeline.Sample(ctx, logits, copy.AsSpan(0, count));
          }
          finally
          {
              ArrayPool<int>.Shared.Return(copy);
          }
  #endif
      }
  }
  /// <summary>
  /// Straightforward <see cref="ISamplingPipeline"/> which runs its configured processors in order on every call
  /// </summary>
  public sealed class ConfigurableSamplingPipeline : ISamplingPipeline
  {
      /// <summary>
      /// Processors applied to the raw logits, in list order
      /// </summary>
      public IList<ILogitProcessor> LogitProcessors { get; } = new List<ILogitProcessor>();

      /// <summary>
      /// Processors applied to the token candidates, in list order
      /// </summary>
      public IList<ITokenDataProcessor> TokenDataProcessors { get; } = new List<ITokenDataProcessor>();

      /// <summary>
      /// Selector which picks the final token from the processed candidates
      /// </summary>
      public ITokenSelector Selector { get; set; } = new StandardSelection();

      /// <inheritdoc />
      public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
      {
          // Stage 1: let every logit processor work on the raw logits
          foreach (var processor in LogitProcessors)
          {
              processor.ProcessLogits(ctx, logits, lastTokens);
          }

          // Stage 2: build the candidate array from the logits and run the token data processors over it
          var candidates = LLamaTokenDataArray.Create(logits);
          foreach (var processor in TokenDataProcessors)
          {
              processor.ProcessTokens(ctx, candidates, lastTokens);
          }

          // Stage 3: choose the token
          var selected = Selector.Select(ctx, candidates, lastTokens);

          // Stage 4: report the chosen token back to all processors
          foreach (var processor in LogitProcessors)
          {
              processor.AcceptToken(ctx, selected);
          }

          foreach (var processor in TokenDataProcessors)
          {
              processor.AcceptToken(ctx, selected);
          }

          return selected;
      }

      /// <inheritdoc />
      public void Reset()
      {
          foreach (var processor in LogitProcessors)
          {
              processor.Reset();
          }

          foreach (var processor in TokenDataProcessors)
          {
              processor.Reset();
          }

          Selector.Reset();
      }

      /// <inheritdoc />
      public void Dispose()
      {
          // Dispose in the same order the stages run: logit processors, token data processors, selector
          foreach (var processor in LogitProcessors)
          {
              processor.Dispose();
          }

          foreach (var processor in TokenDataProcessors)
          {
              processor.Dispose();
          }

          Selector.Dispose();
      }
  }