diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index 5935a0ee..77af7eaf 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.Text; namespace LLama.Common { @@ -83,7 +82,7 @@ namespace LLama.Common /// algorithm described in the paper https://arxiv.org/abs/2007.14966. /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 /// - public MiroStateType Mirostat { get; set; } = MiroStateType.Disable; + public MirostatType Mirostat { get; set; } = MirostatType.Disable; /// /// target entropy /// @@ -98,10 +97,25 @@ namespace LLama.Common public bool PenalizeNL { get; set; } = true; } - public enum MiroStateType + /// + /// Type of "mirostat" sampling to use. + /// https://github.com/basusourya/mirostat + /// + public enum MirostatType { + /// + /// Disable Mirostat sampling + /// Disable = 0, - MiroState = 1, - MiroState2 = 2 + + /// + /// Original mirostat algorithm + /// + Mirostat = 1, + + /// + /// Mirostat 2.0 algorithm + /// + Mirostat2 = 2 } } diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs index f004a782..ec8965e8 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaModel.cs @@ -229,7 +229,7 @@ namespace LLama /// /// /// - public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable, + public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) { llama_token id = 0; @@ -240,14 +240,14 @@ namespace LLama } else { - if (mirostat == MiroStateType.MiroState) + if (mirostat == MirostatType.Mirostat) { float mirostat_mu = 2.0f * mirostatTau; const int mirostat_m = 100; SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu); } - else if (mirostat == MiroStateType.MiroState2) + else if (mirostat == MirostatType.Mirostat2) { float mirostat_mu = 2.0f * mirostatTau; SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);