| @@ -31,7 +31,7 @@ | |||
| }; | |||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||
| using var context = model.CreateContext(parameters); | |||
| var executor = new InstructExecutor(context, null!, InstructionPrefix, InstructionSuffix); | |||
| var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null); | |||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||
| Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions." + | |||
| @@ -58,7 +58,7 @@ namespace LLama.Web.Models | |||
| if (_config.MaxInstances > -1 && ContextCount >= _config.MaxInstances) | |||
| throw new Exception($"Maximum model instances reached"); | |||
| context = _weights.CreateContext(_config); | |||
| context = _weights.CreateContext(_config, _llamaLogger); | |||
| if (_contexts.TryAdd(contextName, context)) | |||
| return Task.FromResult(context); | |||
| @@ -2,6 +2,7 @@ | |||
| using System; | |||
| using LLama.Exceptions; | |||
| using LLama.Abstractions; | |||
| using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| { | |||
| @@ -22,9 +23,10 @@ namespace LLama | |||
| /// Create a new embedder (loading temporary weights) | |||
| /// </summary> | |||
| /// <param name="allParams">Parameters used as both the model parameters and the context parameters</param> | |||
| /// <param name="logger">Optional logger forwarded to the constructor taking separate model and context parameters</param> | |||
| [Obsolete("Preload LLamaWeights and use the constructor which accepts them")] | |||
| public LLamaEmbedder(ILLamaParams allParams) | |||
| : this(allParams, allParams) | |||
| public LLamaEmbedder(ILLamaParams allParams, ILogger? logger = null) | |||
| : this(allParams, allParams, logger) | |||
| { | |||
| } | |||
| @@ -33,13 +35,14 @@ namespace LLama | |||
| /// </summary> | |||
| /// <param name="modelParams">Parameters used to load the temporary weights from file</param> | |||
| /// <param name="contextParams">Parameters used to create the context; EmbeddingMode is set to true on this object</param> | |||
| /// <param name="logger">Optional logger passed to the created context</param> | |||
| [Obsolete("Preload LLamaWeights and use the constructor which accepts them")] | |||
| public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams) | |||
| public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams, ILogger? logger = null) | |||
| { | |||
| using var weights = LLamaWeights.LoadFromFile(modelParams); | |||
| contextParams.EmbeddingMode = true; | |||
| _ctx = weights.CreateContext(contextParams); | |||
| _ctx = weights.CreateContext(contextParams, logger); | |||
| } | |||
| /// <summary> | |||
| @@ -47,10 +50,11 @@ namespace LLama | |||
| /// </summary> | |||
| /// <param name="weights">Already loaded weights from which the context is created</param> | |||
| /// <param name="params">Parameters used to create the context; EmbeddingMode is set to true on this object</param> | |||
| public LLamaEmbedder(LLamaWeights weights, IContextParams @params) | |||
| /// <param name="logger">Optional logger passed to the created context</param> | |||
| public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null) | |||
| { | |||
| @params.EmbeddingMode = true; | |||
| _ctx = weights.CreateContext(@params); | |||
| _ctx = weights.CreateContext(@params, logger); | |||
| } | |||
| /// <summary> | |||
| @@ -89,7 +93,6 @@ namespace LLama | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public float[] GetEmbeddings(string text, bool addBos) | |||
| { | |||
| var embed_inp_array = _ctx.Tokenize(text, addBos); | |||
| // TODO(Rinne): deal with log of prompt | |||
| @@ -75,8 +75,9 @@ namespace LLama | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="logger"></param> | |||
| protected StatefulExecutorBase(LLamaContext context) | |||
| protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) | |||
| { | |||
| _logger = logger; | |||
| Context = context; | |||
| _pastTokensCount = 0; | |||
| _consumedTokensCount = 0; | |||
| @@ -17,7 +17,8 @@ namespace LLama | |||
| /// <summary> | |||
| /// The LLama executor for instruct mode. | |||
| /// </summary> | |||
| public class InstructExecutor : StatefulExecutorBase | |||
| public class InstructExecutor | |||
| : StatefulExecutorBase | |||
| { | |||
| private bool _is_prompt_run = true; | |||
| private readonly string _instructionPrefix; | |||
| @@ -28,11 +29,14 @@ namespace LLama | |||
| /// | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="logger"></param> | |||
| /// <param name="instructionPrefix"></param> | |||
| /// <param name="instructionSuffix"></param> | |||
| public InstructExecutor(LLamaContext context, ILogger logger = null!, string instructionPrefix = "\n\n### Instruction:\n\n", | |||
| string instructionSuffix = "\n\n### Response:\n\n") : base(context) | |||
| /// <param name="logger"></param> | |||
| public InstructExecutor(LLamaContext context, | |||
| string instructionPrefix = "\n\n### Instruction:\n\n", | |||
| string instructionSuffix = "\n\n### Response:\n\n", | |||
| ILogger? logger = null) | |||
| : base(context, logger) | |||
| { | |||
| _inp_pfx = Context.Tokenize(instructionPrefix, true); | |||
| _inp_sfx = Context.Tokenize(instructionSuffix, false); | |||
| @@ -27,7 +27,8 @@ namespace LLama | |||
| /// </summary> | |||
| /// <param name="context">Context passed to the base constructor</param> | |||
| /// <param name="logger">Optional logger passed to the base constructor</param> | |||
| public InteractiveExecutor(LLamaContext context) : base(context) | |||
| public InteractiveExecutor(LLamaContext context, ILogger? logger = null) | |||
| : base(context, logger) | |||
| { | |||
| _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle); | |||
| } | |||
| @@ -21,9 +21,9 @@ namespace LLama | |||
| public class StatelessExecutor | |||
| : ILLamaExecutor | |||
| { | |||
| private readonly ILogger? _logger; | |||
| private readonly LLamaWeights _weights; | |||
| private readonly IContextParams _params; | |||
| private readonly ILogger? _logger; | |||
| /// <summary> | |||
| /// The context used by the executor when running the inference. | |||
| @@ -36,24 +36,25 @@ namespace LLama | |||
| /// <param name="weights">Model weights stored by the executor and used to create contexts</param> | |||
| /// <param name="params">Context parameters stored by the executor and used when creating contexts</param> | |||
| /// <param name="logger">Optional logger stored by the executor and passed to the contexts it creates</param> | |||
| public StatelessExecutor(LLamaWeights weights, IContextParams @params) | |||
| public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null) | |||
| { | |||
| _weights = weights; | |||
| _params = @params; | |||
| _logger = logger; | |||
| Context = _weights.CreateContext(_params); | |||
| Context = _weights.CreateContext(_params, logger); | |||
| Context.Dispose(); | |||
| } | |||
| /// <inheritdoc /> | |||
| public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | |||
| { | |||
| using var context = _weights.CreateContext(_params); | |||
| using var context = _weights.CreateContext(_params, _logger); | |||
| Context = context; | |||
| if (!Context.NativeHandle.IsClosed) | |||
| Context.Dispose(); | |||
| Context = _weights.CreateContext(Context.Params); | |||
| Context = _weights.CreateContext(Context.Params, _logger); | |||
| if (inferenceParams != null) | |||
| { | |||
| @@ -81,10 +81,11 @@ namespace LLama | |||
| /// Create a llama_context using this model | |||
| /// </summary> | |||
| /// <param name="params">Parameters passed to the new context</param> | |||
| /// <param name="logger">Optional logger passed to the new context</param> | |||
| /// <returns>A new LLamaContext constructed from this model and the given arguments</returns> | |||
| public LLamaContext CreateContext(IContextParams @params) | |||
| public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null) | |||
| { | |||
| return new LLamaContext(this, @params); | |||
| return new LLamaContext(this, @params, logger); | |||
| } | |||
| } | |||
| } | |||