diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 4b0c09b3..afbc0f25 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -66,6 +66,12 @@ namespace LLama
/// The model used by the executor.
/// </summary>
public LLamaModel Model => _model;
+
+ /// <summary>
+ /// Current "mu" value for mirostat sampling
+ /// </summary>
+ protected float MirostateMu { get; set; } = float.NaN;
+
/// <summary>
///
/// </summary>
@@ -78,8 +84,6 @@ namespace LLama
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
- _embeds = new();
- _embed_inps = new();
_last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0);
}
@@ -359,24 +363,36 @@ namespace LLama
{
[JsonPropertyName("n_past")]
public int PastTokensCount { get; set; }
+
[JsonPropertyName("n_consumed")]
public int ConsumedTokensCount { get; set; }
+
[JsonPropertyName("n_session_consumed")]
public int ConsumedSessionCount { get; set; }
+
[JsonPropertyName("n_matching_session_tokens")]
public int MatchingSessionTokensCount { get; set; }
+
[JsonPropertyName("path_session")]
public string SessionFilePath { get; set; }
+
[JsonPropertyName("embd")]
public List<llama_token> Embeds { get; set; }
+
[JsonPropertyName("embd_inps")]
public List<llama_token> EmbedInps { get; set; }
+
[JsonPropertyName("session_tokens")]
public List<llama_token> SessionTokens { get; set; }
+
[JsonPropertyName("last_n_tokens")]
public llama_token[] LastTokens { get; set; }
+
[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
+
+ [JsonPropertyName("mirostate_mu")]
+ public float MirostateMu { get; set; }
}
}
}
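
A side note on the new `mirostate_mu` state field: it starts out as `float.NaN`, and System.Text.Json will not write NaN as a plain JSON number. If an executor's state is saved before the first sampling step, serialization needs `JsonNumberHandling.AllowNamedFloatingPointLiterals`. A minimal sketch of the round-trip, with a hypothetical `MuState` standing in for `ExecutorBaseState`:

```csharp
using System;
using System.Text.Json;
using System.Text.Json.Serialization;

class MuState
{
    [JsonPropertyName("mirostate_mu")]
    public float MirostateMu { get; set; } = float.NaN; // mirrors the new property

    static void Main()
    {
        // Without this option, Serialize throws: NaN has no plain JSON number form.
        var options = new JsonSerializerOptions
        {
            NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals
        };

        var json = JsonSerializer.Serialize(new MuState(), options);
        Console.WriteLine(json); // {"mirostate_mu":"NaN"}

        var restored = JsonSerializer.Deserialize<MuState>(json, options);
        Console.WriteLine(float.IsNaN(restored!.MirostateMu)); // True
    }
}
```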
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 613a5f46..89fbac59 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -4,7 +4,6 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
-using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -20,6 +19,7 @@ namespace LLama
string _instructionPrefix;
llama_token[] _inp_pfx;
llama_token[] _inp_sfx;
+
/// <summary>
///
/// </summary>
@@ -51,7 +51,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
- LastTokensCapacity = _last_n_tokens.Capacity
+ LastTokensCapacity = _last_n_tokens.Capacity,
+ MirostateMu = MirostateMu
};
return state;
}
@@ -214,8 +215,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+ var mu = MirostateMu;
+ var id = _model.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+ );
+ MirostateMu = mu;
_last_n_tokens.Enqueue(id);
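
The copy-out/copy-back dance around `Sample` is not accidental: C# does not allow a property to be passed as a `ref` argument (compiler error CS0206), so the property has to go through a local. A self-contained illustration with hypothetical names:

```csharp
class Counter
{
    public int Value { get; set; }

    static void Bump(ref int x) => x++;

    static void Run(Counter c)
    {
        // Bump(ref c.Value);   // CS0206: a property may not be passed as a ref/out argument
        var tmp = c.Value;      // copy the property into a ref-able local
        Bump(ref tmp);          // let the callee mutate it through the ref parameter
        c.Value = tmp;          // write the result back, exactly as done with MirostateMu
    }
}
```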
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 44448aeb..bc3a242e 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -4,12 +4,8 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
-using System.Threading;
-using System.Threading.Tasks;
namespace LLama
{
@@ -21,6 +17,7 @@ namespace LLama
{
bool _is_prompt_run = true;
llama_token[] _llama_token_newline;
+
/// <summary>
///
/// </summary>
@@ -46,7 +43,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
- LastTokensCapacity = _last_n_tokens.Capacity
+ LastTokensCapacity = _last_n_tokens.Capacity,
+ MirostateMu = MirostateMu
};
return state;
}
@@ -204,8 +202,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+ var mu = MirostateMu;
+ var id = _model.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+ );
+ MirostateMu = mu;
_last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2f03c008..9f65f16f 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -220,6 +220,7 @@ namespace LLama
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
/// <param name="candidates"></param>
+ /// <param name="mirostat_mu">Current mirostat "mu" state; float.NaN resets it to 2 * mirostatTau</param>
/// <param name="temperature"></param>
/// <param name="mirostat"></param>
/// <param name="mirostatTau"></param>
@@ -229,10 +230,10 @@ namespace LLama
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
- public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
- float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
+ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
+ float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
- llama_token id = 0;
+ llama_token id;
if (temperature <= 0)
{
// Greedy sampling
@@ -240,16 +241,17 @@ namespace LLama
}
else
{
+ if (float.IsNaN(mirostat_mu))
+ mirostat_mu = 2 * mirostatTau;
+
if (mirostat == MiroStateType.MiroState)
{
- float mirostat_mu = 2.0f * mirostatTau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MiroStateType.MiroState2)
{
- float mirostat_mu = 2.0f * mirostatTau;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
}
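
For background on why `mu` must survive between calls: mirostat is a feedback loop that steers the surprise of sampled tokens toward the target `tau`, adjusting `mu` slightly on every step. The removed lines re-initialized `mu` to `2 * tau` inside each call, which discarded that feedback entirely. A rough sketch of the v2-style update rule (illustrative only; the real work happens in llama.cpp behind `SamplingApi`):

```csharp
using System;

// Illustrative mirostat v2 feedback step, not the native implementation.
static float UpdateMu(float mu, float sampledTokenProb, float tau, float eta)
{
    var surprise = -MathF.Log2(sampledTokenProb); // surprise of the sampled token, in bits
    return mu - eta * (surprise - tau);           // nudge observed surprise toward tau
}

var mu = 2 * 5.0f;                                // first call: mu = 2 * tau
mu = UpdateMu(mu, sampledTokenProb: 0.03f, tau: 5.0f, eta: 0.1f);
// Threading mu through every step is what `ref float mirostat_mu` now enables.
```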
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 0f0ae651..88fa1695 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -57,6 +57,7 @@ namespace LLama
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
+ var mu = float.NaN;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{
@@ -70,7 +71,7 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
lastTokens.Add(id);
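
Note the deliberate asymmetry with the stateful executors: here `mu` is a method-local reset to `float.NaN` on every call, so each inference run gets an independent mirostat trajectory, while the instruct and interactive executors carry it in `MirostateMu` and round-trip it through their saved state. That matches each executor's contract: the stateless one promises no memory across calls, the stateful ones promise to resume exactly where they left off.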