diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index c7cb55fe..9e4292ea 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -17,7 +17,8 @@ namespace LLama /// /// The LLama executor for instruct mode. /// - public class InstructExecutor : StatefulExecutorBase + public class InstructExecutor + : StatefulExecutorBase { private bool _is_prompt_run = true; private readonly string _instructionPrefix; @@ -28,11 +29,14 @@ namespace LLama /// /// /// - /// /// /// - public InstructExecutor(LLamaContext context, ILogger logger = null!, string instructionPrefix = "\n\n### Instruction:\n\n", - string instructionSuffix = "\n\n### Response:\n\n") : base(context) + /// + public InstructExecutor(LLamaContext context, + string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n", + ILogger? logger = null) + : base(context, logger) { _inp_pfx = Context.Tokenize(instructionPrefix, true); _inp_sfx = Context.Tokenize(instructionSuffix, false); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 8247ca10..d3d4a9e3 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -27,7 +27,8 @@ namespace LLama /// /// /// - public InteractiveExecutor(LLamaContext context) : base(context) + public InteractiveExecutor(LLamaContext context, ILogger? logger = null) + : base(context, logger) { _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle); } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index d1b73c2f..80488b71 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -21,9 +21,9 @@ namespace LLama public class StatelessExecutor : ILLamaExecutor { - private readonly ILogger? _logger; private readonly LLamaWeights _weights; private readonly IContextParams _params; + private readonly ILogger? _logger; /// /// The context used by the executor when running the inference. @@ -36,24 +36,25 @@ namespace LLama /// /// /// - public StatelessExecutor(LLamaWeights weights, IContextParams @params) + public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { _weights = weights; _params = @params; + _logger = logger; - Context = _weights.CreateContext(_params); + Context = _weights.CreateContext(_params, logger); Context.Dispose(); } /// public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - using var context = _weights.CreateContext(_params); + using var context = _weights.CreateContext(_params, _logger); Context = context; if (!Context.NativeHandle.IsClosed) Context.Dispose(); - Context = _weights.CreateContext(Context.Params); + Context = _weights.CreateContext(Context.Params, _logger); if (inferenceParams != null) { diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 5dc2024d..64878e2a 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -81,10 +81,11 @@ namespace LLama /// Create a llama_context using this model /// /// + /// /// - public LLamaContext CreateContext(IContextParams @params) + public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null) { - return new LLamaContext(this, @params); + return new LLamaContext(this, @params, logger); } } }