You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BaseSamplingPipeline.cs 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. using System;
  2. using LLama.Native;
  3. namespace LLama.Sampling;
  /// <summary>
  /// Abstract foundation for custom sampling pipelines. It implements the overall
  /// `ISamplingPipeline` flow and leaves logit processing and final token selection
  /// to derived classes.
  /// </summary>
  public abstract class BaseSamplingPipeline
      : ISamplingPipeline
  {
      /// <summary>
      /// Optional grammar used to constrain which tokens are valid. May be null.
      /// </summary>
      public SafeLLamaGrammarHandle? Grammar { get; set; }

      /// <inheritdoc/>
      /// <remarks>
      /// Runs in a fixed order: <see cref="ProcessLogits"/> over the raw logits, construction of a
      /// token data array from those same logits, application of <see cref="Grammar"/>, and finally
      /// <see cref="ProcessTokenDataArray"/>, whose result is returned unchanged.
      /// </remarks>
      public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
      {
          // Stage 1: give the derived class access to the raw logit values.
          ProcessLogits(ctx, logits, lastTokens);

          // Stage 2: build the candidate set from the same span and apply the grammar to it.
          var tokenData = LLamaTokenDataArray.Create(logits);
          tokenData.ApplyGrammar(ctx, Grammar);

          // Stage 3: the derived class picks the final token.
          return ProcessTokenDataArray(ctx, tokenData, lastTokens);
      }

      /// <inheritdoc />
      /// <remarks>
      /// Forwards the accepted token to <see cref="Grammar"/> when one is set; otherwise does nothing.
      /// </remarks>
      public virtual void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
          => Grammar?.AcceptToken(ctx, token);

      /// <summary>
      /// Hook invoked first by <see cref="Sample"/> to process the raw logit values.
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="logits">The logits produced by the model</param>
      /// <param name="lastTokens">A list of tokens recently returned by the model</param>
      protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

      /// <summary>
      /// Hook invoked last by <see cref="Sample"/> to process the LLamaTokenDataArray and pick a single token.
      /// </summary>
      /// <param name="ctx">The context being sampled from</param>
      /// <param name="candidates">The LLamaTokenDataArray built from the logits, after the grammar has been applied</param>
      /// <param name="lastTokens">A list of tokens recently returned by the model</param>
      /// <returns>The selected token; <see cref="Sample"/> returns this value as-is</returns>
      protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);

      /// <inheritdoc/>
      /// <remarks>
      /// The base implementation is empty; override it when the derived pipeline has state to reset.
      /// </remarks>
      public virtual void Reset()
      {
          // Intentionally empty in the base class.
      }

      /// <inheritdoc />
      public abstract ISamplingPipeline Clone();

      /// <inheritdoc/>
      /// <remarks>
      /// The base implementation only suppresses finalization of this instance.
      /// It does not dispose <see cref="Grammar"/>.
      /// </remarks>
      public virtual void Dispose()
          => GC.SuppressFinalize(this);
  }