
Fixed spelling of "mirostat" instead of "mirostate"

tags/v0.4.2-preview
Martin Evans 2 years ago
commit 36735f7908
2 changed files with 22 additions and 8 deletions
  1. LLama/Common/InferenceParams.cs  (+19 / -5)
  2. LLama/LLamaModel.cs  (+3 / -3)

LLama/Common/InferenceParams.cs  (+19 / -5)

@@ -1,6 +1,5 @@
 using System;
 using System.Collections.Generic;
-using System.Text;
 
 namespace LLama.Common
 {
@@ -83,7 +82,7 @@ namespace LLama.Common
         /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
         /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
         /// </summary>
-        public MiroStateType Mirostat { get; set; } = MiroStateType.Disable;
+        public MirostatType Mirostat { get; set; } = MirostatType.Disable;
         /// <summary>
         /// target entropy
         /// </summary>
@@ -98,10 +97,25 @@ namespace LLama.Common
         public bool PenalizeNL { get; set; } = true;
     }
 
-    public enum MiroStateType
+    /// <summary>
+    /// Type of "mirostat" sampling to use.
+    /// https://github.com/basusourya/mirostat
+    /// </summary>
+    public enum MirostatType
     {
+        /// <summary>
+        /// Disable Mirostat sampling
+        /// </summary>
         Disable = 0,
-        MiroState = 1,
-        MiroState2 = 2
+
+        /// <summary>
+        /// Original mirostat algorithm
+        /// </summary>
+        Mirostat = 1,
+
+        /// <summary>
+        /// Mirostat 2.0 algorithm
+        /// </summary>
+        Mirostat2 = 2
     }
 }
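For reference, a minimal sketch of how calling code looks after this rename. Everything here is inferred from the hunk above (the public Mirostat setter and the MirostatType members); that InferenceParams has a parameterless constructor is an assumption about this version of the library.

using LLama.Common;

class MirostatConfigExample
{
    static InferenceParams BuildParams()
    {
        // After this commit, callers reference MirostatType rather than MiroStateType.
        return new InferenceParams
        {
            // 2 = mirostat 2.0 (was MiroStateType.MiroState2 before the rename)
            Mirostat = MirostatType.Mirostat2
        };
    }
}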

LLama/LLamaModel.cs  (+3 / -3)

@@ -229,7 +229,7 @@ namespace LLama
         /// <param name="tfsZ"></param>
         /// <param name="typicalP"></param>
         /// <returns></returns>
-        public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
+        public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
             float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
         {
             llama_token id = 0;
@@ -240,14 +240,14 @@ namespace LLama
             }
             else
             {
-                if (mirostat == MiroStateType.MiroState)
+                if (mirostat == MirostatType.Mirostat)
                 {
                     float mirostat_mu = 2.0f * mirostatTau;
                     const int mirostat_m = 100;
                     SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                     id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
                 }
-                else if (mirostat == MiroStateType.MiroState2)
+                else if (mirostat == MirostatType.Mirostat2)
                 {
                     float mirostat_mu = 2.0f * mirostatTau;
                     SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);

