|
|
|
@@ -229,7 +229,7 @@ namespace LLama |
|
|
|
/// <param name="tfsZ"></param> |
|
|
|
/// <param name="typicalP"></param> |
|
|
|
/// <returns></returns> |
|
|
|
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable, |
|
|
|
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, |
|
|
|
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) |
|
|
|
{ |
|
|
|
llama_token id = 0; |
|
|
|
@@ -240,14 +240,14 @@ namespace LLama |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
if (mirostat == MiroStateType.MiroState) |
|
|
|
if (mirostat == MirostatType.Mirostat) |
|
|
|
{ |
|
|
|
float mirostat_mu = 2.0f * mirostatTau; |
|
|
|
const int mirostat_m = 100; |
|
|
|
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); |
|
|
|
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu); |
|
|
|
} |
|
|
|
else if (mirostat == MiroStateType.MiroState2) |
|
|
|
else if (mirostat == MirostatType.Mirostat2) |
|
|
|
{ |
|
|
|
float mirostat_mu = 2.0f * mirostatTau; |
|
|
|
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); |
|
|
|
|