Browse Source

Refactor conflicting object name SessionOptions

tags/v0.6.0
sa_ddam213 2 years ago
parent
commit
e2a17d6b6f
9 changed files with 43 additions and 30 deletions
  1. +13
    -0
      LLama.Web/Common/ISessionConfig.cs
  2. +1
    -1
      LLama.Web/Common/SessionConfig.cs
  3. +3
    -3
      LLama.Web/Extensions.cs
  4. +1
    -1
      LLama.Web/Hubs/SessionConnectionHub.cs
  5. +15
    -15
      LLama.Web/Models/ModelSession.cs
  6. +5
    -5
      LLama.Web/Pages/Index.cshtml
  7. +2
    -2
      LLama.Web/Pages/Index.cshtml.cs
  8. +2
    -2
      LLama.Web/Services/IModelSessionService.cs
  9. +1
    -1
      LLama.Web/Services/ModelSessionService.cs

+ 13
- 0
LLama.Web/Common/ISessionConfig.cs View File

@@ -0,0 +1,13 @@
namespace LLama.Web.Common
{
public interface ISessionConfig
{
string AntiPrompt { get; set; }
List<string> AntiPrompts { get; set; }
LLamaExecutorType ExecutorType { get; set; }
string Model { get; set; }
string OutputFilter { get; set; }
List<string> OutputFilters { get; set; }
string Prompt { get; set; }
}
}

LLama.Web/Common/SessionOptions.cs → LLama.Web/Common/SessionConfig.cs View File

@@ -1,6 +1,6 @@
namespace LLama.Web.Common
{
public class SessionOptions
public class SessionConfig : ISessionConfig
{
public string Model { get; set; }
public string Prompt { get; set; }

LLama.Web/Extensioms.cs → LLama.Web/Extensions.cs View File

@@ -2,14 +2,14 @@

namespace LLama.Web
{
public static class Extensioms
public static class Extensions
{
/// <summary>
/// Combines the AntiPrompts list and AntiPrompt csv
/// </summary>
/// <param name="sessionConfig">The session configuration.</param>
/// <returns>Combined AntiPrompts with duplicates removed</returns>
public static List<string> GetAntiPrompts(this Common.SessionOptions sessionConfig)
public static List<string> GetAntiPrompts(this ISessionConfig sessionConfig)
{
return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt);
}
@@ -19,7 +19,7 @@ namespace LLama.Web
/// </summary>
/// <param name="sessionConfig">The session configuration.</param>
/// <returns>Combined OutputFilters with duplicates removed</returns>
public static List<string> GetOutputFilters(this Common.SessionOptions sessionConfig)
public static List<string> GetOutputFilters(this ISessionConfig sessionConfig)
{
return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter);
}

+ 1
- 1
LLama.Web/Hubs/SessionConnectionHub.cs View File

@@ -37,7 +37,7 @@ namespace LLama.Web.Hubs


[HubMethodName("LoadModel")]
public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig)
public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig)
{
_logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
await _modelSessionService.CloseAsync(Context.ConnectionId);


+ 15
- 15
LLama.Web/Models/ModelSession.cs View File

@@ -9,21 +9,21 @@ namespace LLama.Web.Models
private readonly LLamaModel _model;
private readonly LLamaContext _context;
private readonly ILLamaExecutor _executor;
private readonly Common.SessionOptions _sessionParams;
private readonly ISessionConfig _sessionConfig;
private readonly ITextStreamTransform _outputTransform;
private readonly InferenceOptions _defaultInferenceConfig;

private CancellationTokenSource _cancellationTokenSource;

public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null)
public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null)
{
_model = model;
_context = context;
_sessionId = sessionId;
_sessionParams = sessionOptions;
_sessionConfig = sessionConfig;
_defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
_outputTransform = CreateOutputFilter(_sessionParams);
_executor = CreateExecutor(_model, _context, _sessionParams);
_outputTransform = CreateOutputFilter();
_executor = CreateExecutor();
}

/// <summary>
@@ -34,7 +34,7 @@ namespace LLama.Web.Models
/// <summary>
/// Gets the name of the model.
/// </summary>
public string ModelName => _sessionParams.Model;
public string ModelName => _sessionConfig.Model;

/// <summary>
/// Gets the context.
@@ -44,7 +44,7 @@ namespace LLama.Web.Models
/// <summary>
/// Gets the session configuration.
/// </summary>
public Common.SessionOptions SessionConfig => _sessionParams;
public ISessionConfig SessionConfig => _sessionConfig;

/// <summary>
/// Gets the inference parameters.
@@ -60,16 +60,16 @@ namespace LLama.Web.Models
/// <param name="cancellationToken">The cancellation token.</param>
internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
{
if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless)
return;

if (string.IsNullOrEmpty(_sessionParams.Prompt))
if (string.IsNullOrEmpty(_sessionConfig.Prompt))
return;

// Run Initial prompt
var inferenceParams = ConfigureInferenceParams(inferenceConfig);
_cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token))
{
// We dont really need the response of the initial prompt, so exit on first token
break;
@@ -114,13 +114,13 @@ namespace LLama.Web.Models
private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
{
var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts();
return inferenceParams;
}

private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
private ITextStreamTransform CreateOutputFilter()
{
var outputFilters = sessionConfig.GetOutputFilters();
var outputFilters = _sessionConfig.GetOutputFilters();
if (outputFilters.Count > 0)
return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);

@@ -128,9 +128,9 @@ namespace LLama.Web.Models
}


private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
private ILLamaExecutor CreateExecutor()
{
return sessionConfig.ExecutorType switch
return _sessionConfig.ExecutorType switch
{
LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
LLamaExecutorType.Instruct => new InstructExecutor(_context),


+ 5
- 5
LLama.Web/Pages/Index.cshtml View File

@@ -24,11 +24,11 @@
<div class="d-flex flex-column m-1">
<div class="d-flex flex-column mb-2">
<small>Model</small>
@Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
@Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
</div>
<div class="d-flex flex-column mb-2">
<small>Inference Type</small>
@Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
@Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
</div>
<nav>
<div class="nav nav-tabs" id="nav-tab" role="tablist">
@@ -40,17 +40,17 @@
<div class="tab-pane fade show active" id="nav-prompt" role="tabpanel" aria-labelledby="nav-prompt-tab">
<div class="d-flex flex-column mb-2">
<small>Prompt</small>
@Html.TextAreaFor(m => Model.SessionOptions.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
@Html.TextAreaFor(m => Model.SessionConfig.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
</div>

<div class="d-flex flex-column mb-2">
<small>AntiPrompts</small>
@Html.TextBoxFor(m => Model.SessionOptions.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
@Html.TextBoxFor(m => Model.SessionConfig.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
</div>

<div class="d-flex flex-column mb-2">
<small>OutputFilter</small>
@Html.TextBoxFor(m => Model.SessionOptions.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
@Html.TextBoxFor(m => Model.SessionConfig.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
</div>
</div>
<div class="tab-pane fade" id="nav-params" role="tabpanel" aria-labelledby="nav-params-tab">


+ 2
- 2
LLama.Web/Pages/Index.cshtml.cs View File

@@ -18,14 +18,14 @@ namespace LLama.Web.Pages
public LLamaOptions Options { get; set; }

[BindProperty]
public Common.SessionOptions SessionOptions { get; set; }
public ISessionConfig SessionConfig { get; set; }

[BindProperty]
public InferenceOptions InferenceOptions { get; set; }

public void OnGet()
{
SessionOptions = new Common.SessionOptions
SessionConfig = new SessionConfig
{
Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
AntiPrompt = "User:",


+ 2
- 2
LLama.Web/Services/IModelSessionService.cs View File

@@ -24,7 +24,7 @@ namespace LLama.Web.Services
/// Creates a new ModelSession
/// </summary>
/// <param name="sessionId">The session identifier.</param>
/// <param name="sessionOptions">The session configuration.</param>
/// <param name="sessionConfig">The session configuration.</param>
/// <param name="inferenceOptions">The default inference configuration, will be used for all inference where no infer configuration is supplied.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
@@ -33,7 +33,7 @@ namespace LLama.Web.Services
/// or
/// Failed to create model session
/// </exception>
Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);
Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);


/// <summary>


+ 1
- 1
LLama.Web/Services/ModelSessionService.cs View File

@@ -65,7 +65,7 @@ namespace LLama.Web.Services
/// or
/// Failed to create model session
/// </exception>
public async Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
{
if (_modelSessions.TryGetValue(sessionId, out _))
throw new Exception($"Session with id {sessionId} already exists");


Loading…
Cancel
Save