Browse Source

update DefaultInferenceParams in WithLLamaSharpDefaults

tags/v0.8.1
xbotter 2 years ago
parent
commit
286904920b
No known key found for this signature in database GPG Key ID: D299220A7FE5CF1E
4 changed files with 19 additions and 12 deletions
  1. +7
    -1
      LLama.Examples/Examples/KernelMemory.cs
  2. +2
    -1
      LLama.Examples/Examples/Runner.cs
  3. +1
    -1
      LLama.KernelMemory/BuilderExtensions.cs
  4. +9
    -9
      LLama.KernelMemory/LlamaSharpTextGeneration.cs

+ 7
- 1
LLama.Examples/Examples/KernelMemory.cs View File

@@ -17,7 +17,13 @@ namespace LLama.Examples.Examples
Console.Write("Please input your model path: "); Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine(); var modelPath = Console.ReadLine();
var memory = new KernelMemoryBuilder() var memory = new KernelMemoryBuilder()
.WithLLamaSharpDefaults(new LLamaSharpConfig(modelPath))
.WithLLamaSharpDefaults(new LLamaSharpConfig(modelPath)
{
DefaultInferenceParams = new Common.InferenceParams
{
AntiPrompts = new List<string> { "\n\n" }
}
})
.With(new TextPartitioningOptions .With(new TextPartitioningOptions
{ {
MaxTokensPerParagraph = 300, MaxTokensPerParagraph = 300,


+ 2
- 1
LLama.Examples/Examples/Runner.cs View File

@@ -42,7 +42,8 @@ public class Runner
AnsiConsole.Write(new Rule(choice)); AnsiConsole.Write(new Rule(choice));
await example(); await example();
} }

Console.WriteLine("Press any key to continue...");
Console.ReadKey();
AnsiConsole.Clear(); AnsiConsole.Clear();
} }
} }


+ 1
- 1
LLama.KernelMemory/BuilderExtensions.cs View File

@@ -82,7 +82,7 @@ namespace LLamaSharp.KernelMemory
var executor = new StatelessExecutor(weights, parameters); var executor = new StatelessExecutor(weights, parameters);
var embedder = new LLamaEmbedder(weights, parameters); var embedder = new LLamaEmbedder(weights, parameters);
builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGeneration(embedder)); builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGeneration(embedder));
builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor));
builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor, config?.DefaultInferenceParams));
return builder; return builder;
} }
} }


+ 9
- 9
LLama.KernelMemory/LlamaSharpTextGeneration.cs View File

@@ -15,10 +15,10 @@ namespace LLamaSharp.KernelMemory
/// </summary> /// </summary>
public class LlamaSharpTextGeneration : ITextGeneration, IDisposable public class LlamaSharpTextGeneration : ITextGeneration, IDisposable
{ {
private readonly LLamaSharpConfig? _config;
private readonly LLamaWeights _weights; private readonly LLamaWeights _weights;
private readonly StatelessExecutor _executor; private readonly StatelessExecutor _executor;
private readonly LLamaContext _context; private readonly LLamaContext _context;
private readonly InferenceParams? _defaultInferenceParams;
private bool _ownsContext = false; private bool _ownsContext = false;
private bool _ownsWeights = false; private bool _ownsWeights = false;


@@ -28,7 +28,6 @@ namespace LLamaSharp.KernelMemory
/// <param name="config">The configuration for LLamaSharp.</param> /// <param name="config">The configuration for LLamaSharp.</param>
public LlamaSharpTextGeneration(LLamaSharpConfig config) public LlamaSharpTextGeneration(LLamaSharpConfig config)
{ {
this._config = config;
var parameters = new ModelParams(config.ModelPath) var parameters = new ModelParams(config.ModelPath)
{ {
ContextSize = config?.ContextSize ?? 2048, ContextSize = config?.ContextSize ?? 2048,
@@ -38,6 +37,7 @@ namespace LLamaSharp.KernelMemory
_weights = LLamaWeights.LoadFromFile(parameters); _weights = LLamaWeights.LoadFromFile(parameters);
_context = _weights.CreateContext(parameters); _context = _weights.CreateContext(parameters);
_executor = new StatelessExecutor(_weights, parameters); _executor = new StatelessExecutor(_weights, parameters);
_defaultInferenceParams = config?.DefaultInferenceParams;
_ownsWeights = _ownsContext = true; _ownsWeights = _ownsContext = true;
} }


@@ -48,12 +48,12 @@ namespace LLamaSharp.KernelMemory
/// <param name="weights">A LLamaWeights object.</param> /// <param name="weights">A LLamaWeights object.</param>
/// <param name="context">A LLamaContext object.</param> /// <param name="context">A LLamaContext object.</param>
/// <param name="executor">An executor. Currently only StatelessExecutor is expected.</param> /// <param name="executor">An executor. Currently only StatelessExecutor is expected.</param>
public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null)
public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null, InferenceParams? inferenceParams = null)
{ {
_config = null;
_weights = weights; _weights = weights;
_context = context; _context = context;
_executor = executor ?? new StatelessExecutor(_weights, _context.Params); _executor = executor ?? new StatelessExecutor(_weights, _context.Params);
_defaultInferenceParams = inferenceParams;
} }


/// <inheritdoc/> /// <inheritdoc/>
@@ -72,7 +72,7 @@ namespace LLamaSharp.KernelMemory
/// <inheritdoc/> /// <inheritdoc/>
public IAsyncEnumerable<string> GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default) public IAsyncEnumerable<string> GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default)
{ {
return _executor.InferAsync(prompt, OptionsToParams(options, this._config?.DefaultInferenceParams), cancellationToken: cancellationToken);
return _executor.InferAsync(prompt, OptionsToParams(options, this._defaultInferenceParams), cancellationToken: cancellationToken);
} }


private static InferenceParams OptionsToParams(TextGenerationOptions options, InferenceParams? defaultParams) private static InferenceParams OptionsToParams(TextGenerationOptions options, InferenceParams? defaultParams)
@@ -82,11 +82,11 @@ namespace LLamaSharp.KernelMemory
return defaultParams with return defaultParams with
{ {
AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(), AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(),
Temperature = options.Temperature == default ? defaultParams.Temperature : default,
Temperature = options.Temperature == defaultParams.Temperature ? defaultParams.Temperature : (float)options.Temperature,
MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens, MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
FrequencyPenalty = options.FrequencyPenalty == default ? defaultParams.FrequencyPenalty : default,
PresencePenalty = options.PresencePenalty == default ? defaultParams.PresencePenalty : default,
TopP = options.TopP == default ? defaultParams.TopP : default
FrequencyPenalty = options.FrequencyPenalty == defaultParams.FrequencyPenalty ? defaultParams.FrequencyPenalty : (float)options.FrequencyPenalty,
PresencePenalty = options.PresencePenalty == defaultParams.PresencePenalty ? defaultParams.PresencePenalty : (float)options.PresencePenalty,
TopP = options.TopP == defaultParams.TopP ? defaultParams.TopP : (float)options.TopP
}; };
} }
else else


Loading…
Cancel
Save