Correctly passing through mu value to mirostate instead of resetting it every time.

tags/v0.4.2-preview
Martin Evans, 2 years ago
commit c64507cb41
5 changed files with 45 additions and 19 deletions:

  1. LLama/LLamaExecutorBase.cs        +18  -2
  2. LLama/LLamaInstructExecutor.cs     +9  -4
  3. LLama/LLamaInteractExecutor.cs     +9  -7
  4. LLama/LLamaModel.cs                +7  -5
  5. LLama/LLamaStatelessExecutor.cs    +2  -1

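What the change fixes: LLamaModel.Sample used to declare a fresh local mirostat_mu = 2.0f * mirostatTau on every call, so the mirostat feedback loop was re-seeded for each token and never accumulated state. Sample now takes mu by ref, with float.NaN as a "not seeded yet" sentinel, and the executors persist the value between tokens. A minimal sketch of the before/after semantics (hypothetical names, not the library's API):

// Sketch only, hypothetical names.
static class MuSketch
{
    // Before: mu was reset per call, so the controller never learned anything.
    public static int SampleBefore(float tau)
    {
        var mu = 2.0f * tau;
        return MirostatStep(ref mu, tau);
    }

    // After: the caller owns mu; NaN means "seed me on the first token only".
    public static int SampleAfter(ref float mu, float tau)
    {
        if (float.IsNaN(mu))
            mu = 2.0f * tau;
        return MirostatStep(ref mu, tau);
    }

    static int MirostatStep(ref float mu, float tau) => 0; // stand-in for the native sampler
}
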
LLama/LLamaExecutorBase.cs (+18 -2)

@@ -66,6 +66,12 @@ namespace LLama
/// The mode used by the executor.
/// </summary>
public LLamaModel Model => _model;

+ /// <summary>
+ /// Current "mu" value for mirostate sampling
+ /// </summary>
+ protected float MirostateMu { get; set; } = float.NaN;

/// <summary>
///
/// </summary>
@@ -78,8 +84,6 @@ namespace LLama
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
- _embeds = new();
- _embed_inps = new();
_last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0);
}

@@ -359,24 +363,36 @@ namespace LLama
{
[JsonPropertyName("n_past")]
public int PastTokensCount { get; set; }

[JsonPropertyName("n_consumed")]
public int ConsumedTokensCount { get; set; }

[JsonPropertyName("n_session_consumed")]
public int ConsumedSessionCount { get; set; }

[JsonPropertyName("n_matching_session_tokens")]
public int MatchingSessionTokensCount { get; set; }

[JsonPropertyName("path_session")]
public string SessionFilePath { get; set; }

[JsonPropertyName("embd")]
public List<llama_token> Embeds { get; set; }

[JsonPropertyName("embd_inps")]
public List<llama_token> EmbedInps { get; set; }

[JsonPropertyName("session_tokens")]
public List<llama_token> SessionTokens { get; set; }

[JsonPropertyName("last_n_tokens")]
public llama_token[] LastTokens { get; set; }

[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }

[JsonPropertyName("mirostate_mu")]
public float MirostateMu { get; set; }
}
}
}
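
Two details in the state block above are worth flagging. MirostateMu now rides along in ExecutorBaseState (JSON name "mirostate_mu"), so a saved executor resumes with its sampling dynamics intact instead of restarting the feedback loop. Its default is float.NaN, though, and System.Text.Json refuses to write NaN unless named floating-point literals are enabled; whether the surrounding code handles that is not visible in this diff. A hedged round-trip sketch:

using System;
using System.Text.Json;
using System.Text.Json.Serialization;

class StateSketch
{
    // Mirrors the property above; the number-handling attribute is an
    // assumption on my part, added so NaN survives serialization.
    [JsonPropertyName("mirostate_mu")]
    [JsonNumberHandling(JsonNumberHandling.AllowNamedFloatingPointLiterals)]
    public float MirostateMu { get; set; } = float.NaN;
}

static class RoundTrip
{
    static void Main()
    {
        var json = JsonSerializer.Serialize(new StateSketch()); // {"mirostate_mu":"NaN"}
        var back = JsonSerializer.Deserialize<StateSketch>(json);
        Console.WriteLine(float.IsNaN(back!.MirostateMu));      // True
    }
}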

LLama/LLamaInstructExecutor.cs (+9 -4)

@@ -4,7 +4,6 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;

@@ -20,6 +19,7 @@ namespace LLama
string _instructionPrefix;
llama_token[] _inp_pfx;
llama_token[] _inp_sfx;

/// <summary>
///
/// </summary>
@@ -51,7 +51,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
- LastTokensCapacity = _last_n_tokens.Capacity
+ LastTokensCapacity = _last_n_tokens.Capacity,
+ MirostateMu = MirostateMu
};
return state;
}
@@ -214,8 +215,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

- var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+ var mu = MirostateMu;
+ var id = _model.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+ );
+ MirostateMu = mu;

_last_n_tokens.Enqueue(id);

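The var mu = MirostateMu; ... MirostateMu = mu; dance around Sample is not noise: C# will not pass a property by ref (compiler error CS0206), so the value is copied into a local, mutated through the ref parameter, and written back. A minimal illustration:

class RefPropertySketch
{
    float MirostateMu { get; set; } = float.NaN;

    static void Mutate(ref float x) => x += 1;

    void Example()
    {
        // Mutate(ref MirostateMu);   // CS0206: properties cannot be ref arguments
        var mu = MirostateMu;         // copy out
        Mutate(ref mu);               // mutate the local through the ref
        MirostateMu = mu;             // write the new value back
    }
}
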
LLama/LLamaInteractExecutor.cs (+9 -7)

@@ -4,12 +4,8 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;

namespace LLama
{
@@ -21,6 +17,7 @@ namespace LLama
{
bool _is_prompt_run = true;
llama_token[] _llama_token_newline;

/// <summary>
///
/// </summary>
@@ -46,7 +43,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
- LastTokensCapacity = _last_n_tokens.Capacity
+ LastTokensCapacity = _last_n_tokens.Capacity,
+ MirostateMu = MirostateMu
};
return state;
}
@@ -204,8 +202,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

- var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+ var mu = MirostateMu;
+ var id = _model.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+ );
+ MirostateMu = mu;

_last_n_tokens.Enqueue(id);

LLama/LLamaModel.cs (+7 -5)

@@ -220,6 +220,7 @@ namespace LLama
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
/// <param name="candidates"></param>
/// <param name="mirostat_mu"></param>
/// <param name="temperature"></param>
/// <param name="mirostat"></param>
/// <param name="mirostatTau"></param>
@@ -229,10 +230,10 @@ namespace LLama
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
- public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
- float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
+ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
+ float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
- llama_token id = 0;
+ llama_token id;
if (temperature <= 0)
{
// Greedy sampling
@@ -240,16 +241,17 @@ namespace LLama
}
else
{
+ if (float.IsNaN(mirostat_mu))
+ mirostat_mu = 2 * mirostatTau;

if (mirostat == MiroStateType.MiroState)
{
- float mirostat_mu = 2.0f * mirostatTau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MiroStateType.MiroState2)
{
- float mirostat_mu = 2.0f * mirostatTau;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
}
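
For context on why a persistent mu matters here: mirostat is a feedback controller that keeps the surprise of sampled tokens (their negative log2-probability) near the target tau, with eta as the learning rate; per the Mirostat paper, the controller nudges mu by eta times the surprise error after each token. Resetting mu to 2 * tau before every sample, as the old code did, threw that correction away each time. A rough rendering of the v2 update step (the real one lives in llama.cpp's llama_sample_token_mirostat_v2; this is only the idea):

using System;

static class MirostatMath
{
    // Rough sketch of the feedback step that consumes the ref mu.
    public static float UpdateMu(float mu, double sampledTokenProb, float tau, float eta)
    {
        var surprise = (float)(-Math.Log2(sampledTokenProb)); // observed -log2 p(token)
        return mu - eta * (surprise - tau);                   // steer surprise toward tau
    }
}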


LLama/LLamaStatelessExecutor.cs (+2 -1)

@@ -57,6 +57,7 @@ namespace LLama
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;

+ var mu = float.NaN;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{
@@ -70,7 +71,7 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

- var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);

lastTokens.Add(id);
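
Note the contrast with the stateful executors: here mu is deliberately a local (var mu = float.NaN;) scoped to a single generation loop. A stateless executor carries nothing between calls, so re-seeding per call is the correct behaviour; within one call, mu still evolves from token to token through the ref parameter.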

