@@ -1,9 +1,99 @@
 using LLama.Common;
+using LLama.Abstractions;
 
 namespace LLama.Web.Common
 {
-    public class ParameterOptions : InferenceParams
-    {
+    public class ParameterOptions : IInferenceParams
+    {
         public string Name { get; set; }
-    }
+
+        /// <summary>
+        /// number of tokens to keep from initial prompt
+        /// </summary>
+        public int TokensKeep { get; set; } = 0;
+        /// <summary>
+        /// how many new tokens to predict (n_predict); set to -1 to generate indefinitely
+        /// until the response completes.
+        /// </summary>
+        public int MaxTokens { get; set; } = -1;
+        /// <summary>
+        /// logit bias for specific tokens
+        /// </summary>
+        public Dictionary<int, float>? LogitBias { get; set; } = null;
+        /// <summary>
+        /// Sequences where the model will stop generating further tokens.
+        /// </summary>
+        public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
+        /// <summary>
+        /// path to file for saving/loading model eval state
+        /// </summary>
+        public string PathSession { get; set; } = string.Empty;
+        /// <summary>
+        /// string to suffix user inputs with
+        /// </summary>
+        public string InputSuffix { get; set; } = string.Empty;
+        /// <summary>
+        /// string to prefix user inputs with
+        /// </summary>
+        public string InputPrefix { get; set; } = string.Empty;
+        /// <summary>
+        /// 0 or lower to use vocab size
+        /// </summary>
+        public int TopK { get; set; } = 40;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TopP { get; set; } = 0.95f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; } = 1.0f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TypicalP { get; set; } = 1.0f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float Temperature { get; set; } = 0.8f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float RepeatPenalty { get; set; } = 1.1f;
+        /// <summary>
+        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+        /// </summary>
+        public int RepeatLastTokensCount { get; set; } = 64;
+        /// <summary>
+        /// frequency penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float FrequencyPenalty { get; set; } = .0f;
+        /// <summary>
+        /// presence penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float PresencePenalty { get; set; } = .0f;
+        /// <summary>
+        /// Mirostat uses tokens instead of words.
+        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
+        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        /// </summary>
+        public MirostatType Mirostat { get; set; } = MirostatType.Disable;
+        /// <summary>
+        /// target entropy
+        /// </summary>
+        public float MirostatTau { get; set; } = 5.0f;
+        /// <summary>
+        /// learning rate
+        /// </summary>
+        public float MirostatEta { get; set; } = 0.1f;
+        /// <summary>
+        /// consider newlines as a repeatable token (penalize_nl)
+        /// </summary>
+        public bool PenalizeNL { get; set; } = true;
+    }
 }
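Since ParameterOptions now implements IInferenceParams directly rather than inheriting the concrete InferenceParams, an options instance can be handed to anything that accepts the interface with no field-by-field copy. A minimal sketch of that wiring (the OptionsDemo helper and its executor argument are hypothetical, not part of this diff):

using System;
using LLama.Abstractions;
using LLama.Web.Common;

public static class OptionsDemo
{
    // Pass ParameterOptions straight through as the IInferenceParams argument.
    public static void Run(ILLamaExecutor executor)
    {
        var options = new ParameterOptions { Name = "default", Temperature = 0.7f };
        foreach (var token in executor.Infer("Hello", options))
            Console.Write(token);
    }
}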
@@ -0,0 +1,117 @@
+using System.Collections.Generic;
+using LLama.Common;
+
+namespace LLama.Abstractions
+{
+    /// <summary>
+    /// The parameters used for inference.
+    /// </summary>
+    public interface IInferenceParams
+    {
+        /// <summary>
+        /// number of tokens to keep from initial prompt
+        /// </summary>
+        public int TokensKeep { get; set; }
+        /// <summary>
+        /// how many new tokens to predict (n_predict); set to -1 to generate indefinitely
+        /// until the response completes.
+        /// </summary>
+        public int MaxTokens { get; set; }
+        /// <summary>
+        /// logit bias for specific tokens
+        /// </summary>
+        public Dictionary<int, float>? LogitBias { get; set; }
+        /// <summary>
+        /// Sequences where the model will stop generating further tokens.
+        /// </summary>
+        public IEnumerable<string> AntiPrompts { get; set; }
+        /// <summary>
+        /// path to file for saving/loading model eval state
+        /// </summary>
+        public string PathSession { get; set; }
+        /// <summary>
+        /// string to suffix user inputs with
+        /// </summary>
+        public string InputSuffix { get; set; }
+        /// <summary>
+        /// string to prefix user inputs with
+        /// </summary>
+        public string InputPrefix { get; set; }
+        /// <summary>
+        /// 0 or lower to use vocab size
+        /// </summary>
+        public int TopK { get; set; }
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TopP { get; set; }
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; }
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TypicalP { get; set; }
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float Temperature { get; set; }
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float RepeatPenalty { get; set; }
+        /// <summary>
+        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+        /// </summary>
+        public int RepeatLastTokensCount { get; set; }
+        /// <summary>
+        /// frequency penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float FrequencyPenalty { get; set; }
+        /// <summary>
+        /// presence penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float PresencePenalty { get; set; }
+        /// <summary>
+        /// Mirostat uses tokens instead of words.
+        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
+        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        /// </summary>
+        public MirostatType Mirostat { get; set; }
+        /// <summary>
+        /// target entropy
+        /// </summary>
+        public float MirostatTau { get; set; }
+        /// <summary>
+        /// learning rate
+        /// </summary>
+        public float MirostatEta { get; set; }
+        /// <summary>
+        /// consider newlines as a repeatable token (penalize_nl)
+        /// </summary>
+        public bool PenalizeNL { get; set; }
+    }
+}
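With the interface extracted, alternative parameter sources can be plain classes. Below is a hedged sketch of a hypothetical near-greedy preset implementing IInferenceParams; the type name and its defaults are illustrative only, not part of this diff:

using System;
using System.Collections.Generic;
using LLama.Abstractions;
using LLama.Common;

namespace LLama.Examples
{
    // Hypothetical preset: a standalone IInferenceParams with near-greedy sampling.
    public class GreedyInferenceParams : IInferenceParams
    {
        public int TokensKeep { get; set; } = 0;
        public int MaxTokens { get; set; } = 256;
        public Dictionary<int, float>? LogitBias { get; set; } = null;
        public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
        public string PathSession { get; set; } = string.Empty;
        public string InputSuffix { get; set; } = string.Empty;
        public string InputPrefix { get; set; } = string.Empty;
        public int TopK { get; set; } = 1;            // keep only the single best token
        public float TopP { get; set; } = 1.0f;       // 1.0 = disabled
        public float TfsZ { get; set; } = 1.0f;
        public float TypicalP { get; set; } = 1.0f;
        public float Temperature { get; set; } = 0.0f;
        public float RepeatPenalty { get; set; } = 1.1f;
        public int RepeatLastTokensCount { get; set; } = 64;
        public float FrequencyPenalty { get; set; } = 0.0f;
        public float PresencePenalty { get; set; } = 0.0f;
        public MirostatType Mirostat { get; set; } = MirostatType.Disable;
        public float MirostatTau { get; set; } = 5.0f;
        public float MirostatEta { get; set; } = 0.1f;
        public bool PenalizeNL { get; set; } = true;
    }
}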
@@ -23,7 +23,7 @@ namespace LLama.Abstractions
         /// <param name="inferenceParams">Any additional parameters</param>
         /// <param name="token">A cancellation token.</param>
         /// <returns></returns>
-        IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
+        IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
 
         /// <summary>
         /// Asynchronously infers a response from the model.
@@ -32,6 +32,6 @@ namespace LLama.Abstractions
         /// <param name="inferenceParams">Any additional parameters</param>
         /// <param name="token">A cancellation token.</param>
         /// <returns></returns>
-        IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
+        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
     }
 }
@@ -138,7 +138,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(ChatHistory history, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             var prompt = HistoryTransform.HistoryToText(history);
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
@@ -159,7 +159,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             foreach(var inputTransform in InputTransformPipeline)
             {
@@ -182,7 +182,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var prompt = HistoryTransform.HistoryToText(history);
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
@@ -202,7 +202,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach (var inputTransform in InputTransformPipeline)
             {
@@ -218,13 +218,13 @@ namespace LLama
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }
 
-        private IEnumerable<string> ChatInternal(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
             return OutputTransform.Transform(results);
         }
 
-        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
             await foreach (var item in OutputTransform.TransformAsync(results))
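The ChatSession overloads accept the interface the same way, so any IInferenceParams implementation can drive a chat. A hedged sketch (the ChatDemo helper and its session argument are hypothetical, not part of this diff):

using System;
using System.Threading.Tasks;
using LLama;
using LLama.Abstractions;
using LLama.Common;

public static class ChatDemo
{
    // Stream a chat reply, passing parameters through the IInferenceParams-typed overload.
    public static async Task RunAsync(ChatSession session)
    {
        IInferenceParams p = new InferenceParams { Temperature = 0.8f, MaxTokens = 128 };
        await foreach (var token in session.ChatAsync("Hi there!", p))
            Console.Write(token);
    }
}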
@@ -1,4 +1,5 @@
-using System;
+using LLama.Abstractions;
+using System;
 using System.Collections.Generic;
 
 namespace LLama.Common
@@ -7,7 +8,7 @@ namespace LLama.Common
     /// <summary>
     /// The parameters used for inference.
     /// </summary>
-    public class InferenceParams
+    public class InferenceParams : IInferenceParams
     {
         /// <summary>
         /// number of tokens to keep from initial prompt
@@ -231,13 +231,13 @@ namespace LLama
         /// <param name="args"></param>
         /// <param name="extraOutputs"></param>
         /// <returns></returns>
-        protected abstract bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
+        protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract void InferInternal(InferenceParams inferenceParams, InferStateArgs args);
+        protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
@@ -267,7 +267,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             if (inferenceParams is null)
@@ -324,7 +324,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach (var result in Infer(text, inferenceParams, cancellationToken))
             {
@@ -1,4 +1,5 @@
-using LLama.Common;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Native;
 using System;
 using System.Collections.Generic;
@@ -136,7 +137,7 @@ namespace LLama
             }
         }
         /// <inheritdoc />
-        protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
         {
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
@@ -179,7 +180,7 @@ namespace LLama
             return false;
         }
         /// <inheritdoc />
-        protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
+        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -1,5 +1,6 @@
 using LLama.Common;
 using LLama.Native;
+using LLama.Abstractions;
 using System;
 using System.Collections.Generic;
 using System.IO;
@@ -122,7 +123,7 @@ namespace LLama
         /// </summary>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
         {
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
@@ -166,7 +167,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
+        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -36,7 +36,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        public IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             int n_past = 1;
@@ -123,7 +123,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        public async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach (var result in Infer(text, inferenceParams, cancellationToken))
             {
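Taken together, the updated signatures make the concrete InferenceParams and any other IInferenceParams implementation interchangeable at every call site. A hedged end-to-end sketch (executor construction elided; the InferDemo helper is hypothetical, not part of this diff):

using System;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;

public static class InferDemo
{
    // Stream tokens via the IInferenceParams-typed async overload.
    public static async Task RunAsync(ILLamaExecutor executor)
    {
        IInferenceParams p = new InferenceParams { Temperature = 0.2f, TopK = 10, MaxTokens = 64 };
        await foreach (var token in executor.InferAsync("Explain the change in one sentence.", p))
            Console.Write(token);
    }
}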