diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 4b0c09b3..afbc0f25 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -66,6 +66,12 @@ namespace LLama
         /// The mode used by the executor.
         /// </summary>
         public LLamaModel Model => _model;
+
+        /// <summary>
+        /// Current "mu" value for mirostate sampling
+        /// </summary>
+        protected float MirostateMu { get; set; } = float.NaN;
+
         /// <summary>
         /// 
         /// </summary>
@@ -78,8 +84,6 @@ namespace LLama
             _pastTokensCount = 0;
             _consumedTokensCount = 0;
             _n_session_consumed = 0;
-            _embeds = new();
-            _embed_inps = new();
             _last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0);
         }
 
@@ -359,24 +363,36 @@ namespace LLama
         {
             [JsonPropertyName("n_past")]
             public int PastTokensCount { get; set; }
+
             [JsonPropertyName("n_consumed")]
             public int ConsumedTokensCount { get; set; }
+
             [JsonPropertyName("n_session_consumed")]
             public int ConsumedSessionCount { get; set; }
+
             [JsonPropertyName("n_matching_session_tokens")]
             public int MatchingSessionTokensCount { get; set; }
+
             [JsonPropertyName("path_session")]
             public string SessionFilePath { get; set; }
+
             [JsonPropertyName("embd")]
             public List<llama_token> Embeds { get; set; }
+
             [JsonPropertyName("embd_inps")]
             public List<llama_token> EmbedInps { get; set; }
+
             [JsonPropertyName("session_tokens")]
             public List<llama_token> SessionTokens { get; set; }
+
             [JsonPropertyName("last_n_tokens")]
             public llama_token[] LastTokens { get; set; }
+
             [JsonPropertyName("last_tokens_maximum_count")]
             public int LastTokensCapacity { get; set; }
+
+            [JsonPropertyName("mirostate_mu")]
+            public float MirostateMu { get; set; }
         }
     }
 }
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 613a5f46..89fbac59 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -4,7 +4,6 @@
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -20,6 +19,7 @@ namespace LLama
         string _instructionPrefix;
         llama_token[] _inp_pfx;
         llama_token[] _inp_sfx;
+
         /// <summary>
         /// 
         /// </summary>
@@ -51,7 +51,8 @@ namespace LLama
                 PastTokensCount = _pastTokensCount,
                 SessionFilePath = _pathSession,
                 SessionTokens = _session_tokens,
-                LastTokensCapacity = _last_n_tokens.Capacity
+                LastTokensCapacity = _last_n_tokens.Capacity,
+                MirostateMu = MirostateMu
             };
             return state;
         }
@@ -214,8 +215,12 @@ namespace LLama
             var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                 inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-            var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+            var mu = MirostateMu;
+            var id = _model.Sample(
+                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+            );
+            MirostateMu = mu;
 
             _last_n_tokens.Enqueue(id);
 
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 44448aeb..bc3a242e 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -4,12 +4,8 @@
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
-using System.Threading;
-using System.Threading.Tasks;
 
 namespace LLama
@@ -21,6 +17,7 @@ namespace LLama
     {
         bool _is_prompt_run = true;
         llama_token[] _llama_token_newline;
+
         /// <summary>
         /// 
         /// </summary>
@@ -46,7 +43,8 @@ namespace LLama
                 PastTokensCount = _pastTokensCount,
                 SessionFilePath = _pathSession,
                 SessionTokens = _session_tokens,
-                LastTokensCapacity = _last_n_tokens.Capacity
+                LastTokensCapacity = _last_n_tokens.Capacity,
+                MirostateMu = MirostateMu
             };
             return state;
         }
@@ -204,8 +202,12 @@ namespace LLama
             var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                 inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-            var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+            var mu = MirostateMu;
+            var id = _model.Sample(
+                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+            );
+            MirostateMu = mu;
 
             _last_n_tokens.Enqueue(id);
 
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2f03c008..9f65f16f 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -220,6 +220,7 @@ namespace LLama
         /// Perform the sampling. Please don't use it unless you fully know what it does.
         /// </summary>
         /// <param name="candidates"></param>
+        /// <param name="mirostat_mu"></param>
         /// <param name="temperature"></param>
         /// <param name="mirostat"></param>
         /// <param name="mirostatTau"></param>
@@ -229,10 +230,10 @@ namespace LLama
         /// <param name="tfsZ"></param>
         /// <param name="typicalP"></param>
         /// <returns></returns>
-        public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
-            float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
+        public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
+            float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
         {
-            llama_token id = 0;
+            llama_token id;
             if (temperature <= 0)
             {
                 // Greedy sampling
@@ -240,16 +241,17 @@ namespace LLama
             }
             else
             {
+                if (float.IsNaN(mirostat_mu))
+                    mirostat_mu = 2 * mirostatTau;
+
                 if (mirostat == MiroStateType.MiroState)
                 {
-                    float mirostat_mu = 2.0f * mirostatTau;
                     const int mirostat_m = 100;
                     SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                     id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
                 }
                 else if (mirostat == MiroStateType.MiroState2)
                 {
-                    float mirostat_mu = 2.0f * mirostatTau;
                     SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                     id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
                 }
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 0f0ae651..88fa1695 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -57,6 +57,7 @@ namespace LLama
             lastTokens.AddRange(tokens);
             n_past += n_prompt_tokens;
 
+            var mu = float.NaN;
             int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(int i = 0; i < max_tokens; i++)
             {
@@ -70,7 +71,7 @@ namespace LLama
                 var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                     inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
                 lastTokens.Add(id);
 
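
Taken together, the patch makes mirostat sampling stateful: Sample now takes mu by ref, seeds it with 2 * mirostatTau only on first use (signalled by float.NaN) instead of resetting it on every call, and each executor persists it between tokens and serializes it as "mirostate_mu". The sketch below is illustrative only, not LLamaSharp code; it shows why the caller-owned, NaN-seeded mu pattern lets the per-token mirostat update survive across calls.

using System;

// Standalone sketch of the "ref mu" pattern this patch introduces.
class MirostatMuSketch
{
    // Stand-in for LLamaModel.Sample: seeds mu on first use (NaN marks
    // "not initialised yet", mirroring MirostateMu = float.NaN), then
    // applies the mirostat update mu -= eta * (surprise - tau).
    static void Sample(ref float mu, float tau, float eta, float surprise)
    {
        if (float.IsNaN(mu))
            mu = 2 * tau;              // same seeding rule as the patch

        mu -= eta * (surprise - tau);  // state carried to the next token
    }

    static void Main()
    {
        var mu = float.NaN;            // caller-owned, like MirostateMu
        foreach (var surprise in new[] { 6.0f, 4.5f, 5.2f })
        {
            Sample(ref mu, 5.0f, 0.1f, surprise);  // tau = 5, eta = 0.1
            Console.WriteLine($"mu = {mu}");       // updates persist across calls
        }
    }
}

Before the change, Sample re-initialised mu to 2 * mirostatTau on every call, so the feedback loop that keeps observed surprise near tau never accumulated; moving ownership of mu to the executors (and into their serialized state) fixes that without changing any sampling API other than the added ref parameter.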