diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index 5935a0ee..77af7eaf 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
-using System.Text;
namespace LLama.Common
{
@@ -83,7 +82,7 @@ namespace LLama.Common
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
///
- public MiroStateType Mirostat { get; set; } = MiroStateType.Disable;
+ public MirostatType Mirostat { get; set; } = MirostatType.Disable;
///
/// target entropy
///
@@ -98,10 +97,25 @@ namespace LLama.Common
public bool PenalizeNL { get; set; } = true;
}
- public enum MiroStateType
+ ///
+ /// Type of "mirostat" sampling to use.
+ /// https://github.com/basusourya/mirostat
+ ///
+ public enum MirostatType
{
+ ///
+ /// Disable Mirostat sampling
+ ///
Disable = 0,
- MiroState = 1,
- MiroState2 = 2
+
+ ///
+ /// Original mirostat algorithm
+ ///
+ Mirostat = 1,
+
+ ///
+ /// Mirostat 2.0 algorithm
+ ///
+ Mirostat2 = 2
}
}
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index f004a782..ec8965e8 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -229,7 +229,7 @@ namespace LLama
///
///
///
- public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
+ public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id = 0;
@@ -240,14 +240,14 @@ namespace LLama
}
else
{
- if (mirostat == MiroStateType.MiroState)
+ if (mirostat == MirostatType.Mirostat)
{
float mirostat_mu = 2.0f * mirostatTau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
- else if (mirostat == MiroStateType.MiroState2)
+ else if (mirostat == MirostatType.Mirostat2)
{
float mirostat_mu = 2.0f * mirostatTau;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);