You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Mirostat2SamplingPipeline.cs 1.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. using System;
  2. using LLama.Native;
  3. namespace LLama.Sampling;
  /// <summary>
  /// Sampling pipeline that selects each token with the mirostat (v2) algorithm,
  /// driven by the <see cref="Tau"/> target and the <see cref="Eta"/> learning rate,
  /// and carrying the learned <see cref="Mu"/> value from one token to the next.
  /// </summary>
  // NOTE(review): the type name is spelled "Mirostate" while the file is named
  // Mirostat2SamplingPipeline.cs; renaming the type would be a breaking change for callers.
  public class Mirostate2SamplingPipeline
      : BaseSamplingPipeline
  {
      private const float DefaultTau = 5;

      private float _tau = DefaultTau;
      private float _mu = InitialMuFor(DefaultTau);

      /// <summary>
      /// Current learned mu value. It starts at twice <see cref="Tau"/> and is passed
      /// by reference into the mirostat v2 sampler for every selected token.
      /// </summary>
      public float Mu => _mu;

      /// <summary>
      /// Target entropy. Assigning a new value also resets <see cref="Mu"/> to twice that value.
      /// </summary>
      public float Tau
      {
          get => _tau;
          set
          {
              _tau = value;
              _mu = InitialMuFor(value);
          }
      }

      /// <summary>
      /// Learning rate used by the mirostat v2 sampler when adjusting <see cref="Mu"/>.
      /// </summary>
      public float Eta { get; set; } = 0.1f;

      /// <summary>
      /// Starting mu for a given target entropy: twice the target.
      /// </summary>
      private static float InitialMuFor(float tau) => tau * 2;

      // Mirostat v2 does not alter the logits; all of its work happens in token selection.
      /// <inheritdoc />
      protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
          => logits;

      /// <inheritdoc />
      protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
          => candidates.SampleTokenMirostat2(ctx, _tau, Eta, ref _mu);

      /// <inheritdoc />
      public override void Reset()
      {
          base.Reset();
          _mu = InitialMuFor(_tau);
      }

      /// <inheritdoc />
      public override ISamplingPipeline Clone()
      {
          var copy = new Mirostate2SamplingPipeline
          {
              Grammar = Grammar?.Clone(),
              Eta = Eta,
          };

          // Write the private state directly: going through the Tau setter would
          // overwrite the learned mu with its initial value.
          copy._tau = _tau;
          copy._mu = _mu;
          return copy;
      }
  }