You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

IContextParamsExtensions.cs 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. using System;
  2. using System.IO;
  3. using LLama.Abstractions;
  4. using LLama.Native;
  namespace LLama.Extensions
  {
      /// <summary>
      /// Extension methods for the <see cref="IContextParams"/> interface.
      /// </summary>
      public static class IContextParamsExtensions
      {
          /// <summary>
          /// Convert the given <see cref="IContextParams"/> into a native <see cref="LLamaContextParams"/>.
          /// The result starts from <c>NativeApi.llama_context_default_params()</c>. Every field that
          /// <paramref name="params"/> exposes is then overwritten. Optional (nullable) settings that are
          /// not set fall back to the fixed defaults written below.
          /// </summary>
          /// <param name="params">The managed context parameters to convert.</param>
          /// <param name="result">Receives the native context parameters built from <paramref name="params"/>.</param>
          /// <exception cref="ArgumentNullException"><paramref name="params"/> is <c>null</c>.</exception>
          public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
          {
              // Fail fast with a descriptive exception rather than a NullReferenceException below.
              if (@params is null)
              {
                  throw new ArgumentNullException(nameof(@params));
              }

              result = NativeApi.llama_context_default_params();
              result.n_ctx = @params.ContextSize;
              result.n_batch = @params.BatchSize;
              result.seed = @params.Seed;
              result.f16_kv = @params.UseFp16Memory;
              result.logits_all = @params.Perplexity;
              result.embedding = @params.EmbeddingMode;

              // Unset RoPE settings are passed to the native side as 0.
              // NOTE(review): presumably 0 means "use the value from the model" in llama.cpp - confirm.
              result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
              result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;

              // Default YaRN values copied from here: https://github.com/ggerganov/llama.cpp/blob/381efbf480959bb6d1e247a8b0c2328f22e350f8/common/common.h#L67
              result.yarn_ext_factor = @params.YarnExtrapolationFactor ?? -1f;
              result.yarn_attn_factor = @params.YarnAttentionFactor ?? 1f;
              result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
              result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
              result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
              result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;

              result.mul_mat_q = @params.MulMatQ;
              result.n_threads = Threads(@params.Threads);
              result.n_threads_batch = Threads(@params.BatchThreads);
          }

          /// <summary>
          /// Resolve a thread count. A configured value is used when it is set and greater than zero.
          /// Otherwise the count is half of <see cref="Environment.ProcessorCount"/>, and never less than 1.
          /// </summary>
          /// <param name="value">The configured thread count, or <c>null</c> if not configured.</param>
          /// <returns>The number of threads to use; always at least 1.</returns>
          private static uint Threads(uint? value)
          {
              if (value is > 0)
              {
                  return value.Value;
              }

              return (uint)Math.Max(Environment.ProcessorCount / 2, 1);
          }
      }
  }