@@ -66,6 +66,12 @@ namespace LLama
/// The model used by the executor.
/// </summary>
public LLamaModel Model => _model;
/// <summary>
/// Current "mu" value for mirostat sampling
/// </summary>
protected float MirostateMu { get; set; } = float.NaN;
/// <summary>
///
/// </summary>
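The new MirostateMu property uses float.NaN as a "not initialized yet" sentinel, so the first mirostat sample can seed it from tau and every later sample keeps refining the same value. A minimal sketch of that pattern (hypothetical type and member names, not LLamaSharp's API):

```csharp
using System;

// Hypothetical sketch of the NaN-sentinel pattern behind MirostateMu.
class MirostatStateSketch
{
    // float.NaN means "no sampling has happened yet"
    public float Mu { get; private set; } = float.NaN;

    public void EnsureInitialized(float tau)
    {
        if (float.IsNaN(Mu))
            Mu = 2.0f * tau;             // common starting value: twice the target surprise
    }

    public void Update(float newMu) => Mu = newMu;   // written back after each token
}
```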
@@ -78,8 +84,6 @@ namespace LLama
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
_embeds = new();
_embed_inps = new();
_last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0);
}
@@ -359,24 +363,36 @@ namespace LLama
{
[JsonPropertyName("n_past")]
public int PastTokensCount { get; set; }
[JsonPropertyName("n_consumed")]
public int ConsumedTokensCount { get; set; }
[JsonPropertyName("n_session_consumed")]
public int ConsumedSessionCount { get; set; }
[JsonPropertyName("n_matching_session_tokens")]
public int MatchingSessionTokensCount { get; set; }
[JsonPropertyName("path_session")]
public string SessionFilePath { get; set; }
[JsonPropertyName("embd")]
public List<llama_token> Embeds { get; set; }
[JsonPropertyName("embd_inps")]
public List<llama_token> EmbedInps { get; set; }
[JsonPropertyName("session_tokens")]
public List<llama_token> SessionTokens { get; set; }
[JsonPropertyName("last_n_tokens")]
public llama_token[] LastTokens { get; set; }
[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
[JsonPropertyName("mirostate_mu")]
public float MirostateMu { get; set; }
}
}
}
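Because MirostateMu defaults to float.NaN, persisting this state with System.Text.Json needs NaN-aware number handling; by default the serializer rejects non-finite floats. A hedged sketch of a round trip (the GetStateData() accessor and the file path are assumptions for illustration; ExecutorBaseState and the "mirostate_mu" property name come from the diff):

```csharp
using System.IO;
using System.Text.Json;
using System.Text.Json.Serialization;

// Allow NaN/Infinity to be written and read as the literals "NaN"/"Infinity".
var options = new JsonSerializerOptions
{
    NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals,
    WriteIndented = true
};

var state = executor.GetStateData();                       // assumed accessor name
File.WriteAllText("executor_state.json", JsonSerializer.Serialize(state, options));

var restored = JsonSerializer.Deserialize<ExecutorBaseState>(
    File.ReadAllText("executor_state.json"), options);
// "mirostate_mu" now travels with the rest of the state, so a resumed session
// continues the mirostat controller instead of silently resetting it.
```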
@@ -4,7 +4,6 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -20,6 +19,7 @@ namespace LLama
string _instructionPrefix;
llama_token[] _inp_pfx;
llama_token[] _inp_sfx;
/// <summary>
///
/// </summary>
@@ -51,7 +51,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity
LastTokensCapacity = _last_n_tokens.Capacity,
MirostateMu = MirostateMu
};
return state;
}
@@ -214,8 +215,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
var mu = MirostateMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
MirostateMu = mu;
_last_n_tokens.Enqueue(id);
@@ -4,12 +4,8 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
@@ -21,6 +17,7 @@ namespace LLama
{
bool _is_prompt_run = true;
llama_token[] _llama_token_newline;
/// <summary>
///
/// </summary>
@@ -46,7 +43,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity
LastTokensCapacity = _last_n_tokens.Capacity,
MirostateMu = MirostateMu
};
return state;
}
@@ -204,8 +202,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
var mu = MirostateMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
MirostateMu = mu;
_last_n_tokens.Enqueue(id);
@@ -220,6 +220,7 @@ namespace LLama
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
/// <param name="candidates"></param>
/// <param name="mirostat_mu"></param>
/// <param name="temperature"></param>
/// <param name="mirostat"></param>
/// <param name="mirostatTau"></param>
@@ -229,10 +230,10 @@ namespace LLama
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id = 0;
llama_token id;
if (temperature <= 0)
{
// Greedy sampling
@@ -240,16 +241,17 @@ namespace LLama
}
else
{
if (float.IsNaN(mirostat_mu))
mirostat_mu = 2 * mirostatTau;
if (mirostat == MiroStateType.MiroState)
{
float mirostat_mu = 2.0f * mirostatTau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MiroStateType.MiroState2)
{
float mirostat_mu = 2.0f * mirostatTau;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
}
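The substantive change in Sample is that mu is no longer a local reset to 2 * mirostatTau inside each branch; the caller now owns it and passes it by ref, with NaN triggering one-time initialization. That matters because mirostat is a feedback controller: each sampled token's surprise nudges mu toward the target tau, and resetting mu on every token discards that feedback. A simplified, illustrative version of the update step (not the native llama.cpp routine):

```csharp
using System;

static class MirostatSketch
{
    // Illustrative feedback step: mu is a running "surprise budget" that the
    // sampler adjusts after every token so average surprise tracks tau.
    public static void UpdateMu(ref float mu, float chosenTokenProbability,
                                float tau, float eta)
    {
        if (float.IsNaN(mu))
            mu = 2.0f * tau;                                  // first-call seed

        float observedSurprise = -MathF.Log2(chosenTokenProbability);
        mu -= eta * (observedSurprise - tau);                 // error-driven correction
    }
}
```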
@@ -57,6 +57,7 @@ namespace LLama
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
var mu = float.NaN;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{
@@ -70,7 +71,7 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
lastTokens.Add(id);
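In the stateless executor the same mechanism is scoped to a single inference call: one local mu spans the whole generation loop, so mirostat state carries from token to token but is discarded when the call returns. A rough sketch of that calling pattern (Sample's signature and MiroStateType come from the diff; the surrounding variables are placeholders):

```csharp
// Placeholder loop showing how the stateless executor threads mu through Sample.
var mu = float.NaN;                       // fresh controller state for this call only
for (int i = 0; i < maxTokens; i++)
{
    var id = model.Sample(
        candidates, ref mu,
        temperature: 0.8f,
        mirostat: MiroStateType.MiroState2,
        mirostatTau: 5.0f,
        mirostatEta: 0.1f);
    // ... feed id back to the model, rebuild candidates, check for end-of-sequence ...
}
```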