
LLamaStatelessExecutor.cs

April 2024 Binary Update (#662)

* Updated binaries, using [this build](https://github.com/SciSharp/LLamaSharp/actions/runs/8654672719/job/23733195669) for llama.cpp commit `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
  - Added all new functions.
  - Moved some functions (e.g. `SafeLlamaModelHandle` specific functions) into `SafeLlamaModelHandle.cs`.
  - Exposed tokens on `SafeLlamaModelHandle` and `LLamaWeights` through a `Tokens` property. As new special tokens are added in the future they can be added here.
  - Changed all token properties to return nullable tokens, to handle some models not having some tokens.
  - Fixed `DefaultSamplingPipeline` to handle no newline token in some models.
* Moved native methods to more specific locations.
  - Context specific things have been moved into `SafeLLamaContextHandle.cs` and made private - they're exposed through C# properties and methods already.
  - Checking that GPU layer count is zero if GPU offload is not supported.
  - Moved methods for creating default structs (`llama_model_quantize_default_params` and `llama_context_default_params`) into relevant structs.
* Removed exception if `GpuLayerCount > 0` when GPU is not supported.
* Added low level wrapper methods for new per-sequence state load/save in `SafeLLamaContextHandle`.
* Added high level wrapper methods (save/load with `State` object or memory mapped file) in `LLamaContext`.
* Moved native methods for per-sequence state load/save into `SafeLLamaContextHandle`.
* Added update and defrag methods for KV cache in `SafeLLamaContextHandle`.
* Updated submodule to `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
* Passing the sequence ID when saving a single sequence state.
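As a quick illustration of the nullable token change described above, here is a minimal sketch (not part of this commit; the model path is hypothetical) of checking whether a model defines an EOS token through the `LLamaWeights.Tokens` property:

using System;
using LLama;
using LLama.Common;
using LLama.Native;

// Hypothetical model path - substitute a real GGUF file
var parameters = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);

// Token properties are nullable because some models do not define every special token
LLamaToken? eos = weights.Tokens.EOS;
if (eos is null)
    Console.WriteLine("This model does not define an EOS token.");
else
    Console.WriteLine($"EOS token id: {eos}");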
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using LLama.Exceptions;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;

namespace LLama
{
    /// <summary>
    /// This executor infers the input as a one-time job. Previous inputs won't impact the
    /// response to the current input.
    /// </summary>
    public class StatelessExecutor
        : ILLamaExecutor
    {
        private readonly LLamaWeights _weights;
        private readonly IContextParams _params;
        private readonly ILogger? _logger;
        private readonly LLamaBatch _batch;

        // LLava Section
        public bool IsMultiModal => false;

        /// <inheritdoc />
        public bool MultiModalProject { get; }

        /// <inheritdoc />
        public LLavaWeights? ClipModel { get; }

        /// <inheritdoc />
        public List<byte[]> Images { get; set; }

        /// <summary>
        /// The context used by the executor when running the inference.
        /// </summary>
        public LLamaContext Context { get; private set; }

        /// <summary>
        /// Create a new stateless executor which will use the given model
        /// </summary>
        /// <param name="weights"></param>
        /// <param name="params"></param>
        /// <param name="logger"></param>
        public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            Images = new List<byte[]>();
            _weights = weights;
            _params = @params;
            _logger = logger;
            _batch = new LLamaBatch();

            Context = _weights.CreateContext(_params, logger);
            Context.Dispose();
        }

        /// <inheritdoc />
        public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Ensure the context from last time is disposed (it always should be)
            if (!Context.NativeHandle.IsClosed)
                Context.Dispose();

            // Create an inference context which will be disposed when this method exits
            using var context = _weights.CreateContext(_params, _logger);
            Context = context;

            // Reset the sampling pipeline (if there is one)
            inferenceParams?.SamplingPipeline?.Reset();

            // Sanity check inference params
            inferenceParams ??= new InferenceParams();
            if (inferenceParams.TokensKeep > Context.ContextSize)
                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");

            // Create decoders for the token stream
            var decoder = new StreamingTokenDecoder(Context);
            var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);

            // Keep track of the last N tokens emitted
            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
            var lastTokens = new List<LLamaToken>(repeat_last_n);
            for (var i = 0; i < repeat_last_n; i++)
                lastTokens.Add(0);

            // Tokenize the prompt
            var tokens = Context.Tokenize(prompt, special: true).ToList();
            lastTokens.AddRange(tokens);

            // Evaluate the prompt, in chunks smaller than the max batch size
            var n_past = 0;
            var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
            if (r != DecodeResult.Ok)
                throw new LLamaDecodeError(r);

            // Begin loop, evaluating one token at a time
            var mu = (float?)null;
            var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
            {
                LLamaToken id;
                if (inferenceParams.SamplingPipeline is not null)
                {
                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
                }
                else
                {
                    // Penalize the generated tokens by various penalties
                    var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                    // Sample a single token
                    id = Context.Sample(
                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
                        inferenceParams.MinP
                    );
                }

                // Check if this is the EOS token
                if (id == _weights.Tokens.EOS)
                    break;

                // Decode this token into text
                decoder.Add(id);
                var decoded = decoder.Read();
                yield return decoded;

                // Check if any of the antiprompts have been generated
                if (antiprocessor.Add(decoded))
                    break;

                lastTokens.Add(id);
                tokens.Clear();
                tokens.Add(id);

                // When the context runs out of space: keep the first TokensKeep tokens, discard half
                // of the rest and shift the remaining KV cache entries back to make room.
                // Based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                if (n_past + tokens.Count >= Context.ContextSize)
                {
                    var n_left = n_past - inferenceParams.TokensKeep - 1;
                    var n_discard = n_left / 2;

                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

                    n_past -= n_discard;
                }

                // Evaluate with this new token
                _batch.Clear();
                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);

                var returnCode = await context.DecodeAsync(_batch, cancellationToken);
                if (returnCode != 0)
                    throw new LLamaDecodeError(returnCode);
            }
        }
    }
}
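For context, a minimal usage sketch of this executor (not part of the file above; the model path and prompt are hypothetical placeholders) might look like this:

using System;
using LLama;
using LLama.Common;

// Hypothetical model path - substitute a real GGUF file
var parameters = new ModelParams("path/to/model.gguf") { ContextSize = 1024 };
using var weights = LLamaWeights.LoadFromFile(parameters);

// Each InferAsync call is independent: no state is carried over between prompts
var executor = new StatelessExecutor(weights, parameters);
var inferenceParams = new InferenceParams { MaxTokens = 64, AntiPrompts = new[] { "Question:" } };

await foreach (var text in executor.InferAsync("Question: what is an apple?\nAnswer:", inferenceParams))
    Console.Write(text);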