You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

ISamplingPipeline.cs 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Runtime.InteropServices;
  5. using LLama.Native;
  6. namespace LLama.Sampling;
  /// <summary>
  /// Turns a span of logits into one sampled token. Implement this interface to take complete control of the sampling process.
  /// </summary>
  public interface ISamplingPipeline : IDisposable
  {
      /// <summary>
      /// Pick a single token based on the supplied logits
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="logits">The logits produced by the model</param>
      /// <param name="lastTokens">A span of tokens recently returned by the model</param>
      /// <returns>The sampled token</returns>
      LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

      /// <summary>
      /// Inform the pipeline that a particular token has just been accepted, so it can update itself
      /// </summary>
      /// <param name="ctx">The context in use</param>
      /// <param name="token">The token that was accepted</param>
      void Accept(SafeLLamaContextHandle ctx, LLamaToken token);

      /// <summary>
      /// Reset all internal state held by this sampling pipeline
      /// </summary>
      void Reset();

      /// <summary>
      /// Create a copy of this sampling pipeline
      /// </summary>
      /// <returns>The copied pipeline</returns>
      ISamplingPipeline Clone();
  }
  /// <summary>
  /// Extension methods for <see cref="ISamplingPipeline"/>
  /// </summary>
  public static class ISamplingPipelineExtensions
  {
      /// <summary>
      /// Sample a single token from the given logits, accepting the recent tokens as a list
      /// instead of a span.
      /// </summary>
      /// <param name="pipeline">The pipeline to sample with</param>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="logits">The logits produced by the model</param>
      /// <param name="lastTokens">A list of tokens recently returned by the model</param>
      /// <returns>The token returned by <paramref name="pipeline"/></returns>
      public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens)
      {
  #if NET5_0_OR_GREATER
          // View the list's contents as a span directly, without copying.
          var span = CollectionsMarshal.AsSpan(lastTokens);
          return pipeline.Sample(ctx, logits, span);
  #else
          // No span view over a List is available here, so copy into a pooled array.
          // The pool may hand back an array LONGER than requested: only the first
          // `count` elements are real tokens, the remainder is stale pool content
          // and must not be exposed to the pipeline.
          var count = lastTokens.Count;
          var copy = ArrayPool<LLamaToken>.Shared.Rent(count);
          try
          {
              lastTokens.CopyTo(copy);
              return pipeline.Sample(ctx, logits, copy.AsSpan(0, count));
          }
          finally
          {
              ArrayPool<LLamaToken>.Shared.Return(copy);
          }
  #endif
      }
  }