You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ISamplingPipeline.cs 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Runtime.InteropServices;
  5. using LLama.Native;
  6. namespace LLama.Sampling;
  7. /// <summary>
  8. /// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
  9. /// </summary>
  10. public interface ISamplingPipeline
  11. : IDisposable
  12. {
  13. /// <summary>
  14. /// Sample a single token from the given logits
  15. /// </summary>
  16. /// <param name="ctx">The context being sampled from</param>
  17. /// <param name="logits">The logits produced by the model</param>
  18. /// <param name="lastTokens">A span of tokens recently returned by the model</param>
  19. /// <returns></returns>
  20. int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
  21. /// <summary>
  22. /// Reset all internal state of the sampling pipeline
  23. /// </summary>
  24. void Reset();
  25. }
  26. /// <summary>
  27. /// Extensions methods for ISamplingPipeline
  28. /// </summary>
  29. public static class ISamplingPipelineExtensions
  30. {
  31. /// <summary>
  32. /// Sample a single token from the given logits
  33. /// </summary>
  34. /// <param name="pipeline"></param>
  35. /// <param name="ctx">The context being sampled from</param>
  36. /// <param name="logits">The logits produced by the model</param>
  37. /// <param name="lastTokens">A list of tokens recently returned by the model</param>
  38. /// <returns></returns>
  39. public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
  40. {
  41. #if NET5_0_OR_GREATER
  42. var span = CollectionsMarshal.AsSpan(lastTokens);
  43. return pipeline.Sample(ctx, logits, span);
  44. #else
  45. var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
  46. try
  47. {
  48. lastTokens.CopyTo(copy);
  49. return pipeline.Sample(ctx, logits, copy.AsSpan(0, copy.Length));
  50. }
  51. finally
  52. {
  53. ArrayPool<int>.Shared.Return(copy);
  54. }
  55. #endif
  56. }
  57. }