
Refactor conflicting object name SessionOptions

tags/v0.6.0
sa_ddam213, 2 years ago
commit e2a17d6b6f
9 changed files with 43 additions and 30 deletions

  1. LLama.Web/Common/ISessionConfig.cs (+13, -0)
  2. LLama.Web/Common/SessionConfig.cs (+1, -1)
  3. LLama.Web/Extensions.cs (+3, -3)
  4. LLama.Web/Hubs/SessionConnectionHub.cs (+1, -1)
  5. LLama.Web/Models/ModelSession.cs (+15, -15)
  6. LLama.Web/Pages/Index.cshtml (+5, -5)
  7. LLama.Web/Pages/Index.cshtml.cs (+2, -2)
  8. LLama.Web/Services/IModelSessionService.cs (+2, -2)
  9. LLama.Web/Services/ModelSessionService.cs (+1, -1)

LLama.Web/Common/ISessionConfig.cs (+13, -0)

@@ -0,0 +1,13 @@
+namespace LLama.Web.Common
+{
+    public interface ISessionConfig
+    {
+        string AntiPrompt { get; set; }
+        List<string> AntiPrompts { get; set; }
+        LLamaExecutorType ExecutorType { get; set; }
+        string Model { get; set; }
+        string OutputFilter { get; set; }
+        List<string> OutputFilters { get; set; }
+        string Prompt { get; set; }
+    }
+}
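The renamed SessionConfig class (next file) implements this interface, but the diff below shows it only in part. As a rough sketch of the assumed full shape, not the committed file:

// Sketch only - assumed full shape of SessionConfig; only a fragment appears in the diff below.
namespace LLama.Web.Common
{
    public class SessionConfig : ISessionConfig
    {
        public string Model { get; set; }
        public string Prompt { get; set; }
        public string AntiPrompt { get; set; }
        public List<string> AntiPrompts { get; set; } = new List<string>();
        public string OutputFilter { get; set; }
        public List<string> OutputFilters { get; set; } = new List<string>();
        public LLamaExecutorType ExecutorType { get; set; }
    }
}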

LLama.Web/Common/SessionOptions.cs → LLama.Web/Common/SessionConfig.cs (+1, -1)

@@ -1,6 +1,6 @@
 namespace LLama.Web.Common
 {
-    public class SessionOptions
+    public class SessionConfig : ISessionConfig
     {
         public string Model { get; set; }
         public string Prompt { get; set; }

LLama.Web/Extensioms.cs → LLama.Web/Extensions.cs (+3, -3)

@@ -2,14 +2,14 @@

 namespace LLama.Web
 {
-    public static class Extensioms
+    public static class Extensions
     {
         /// <summary>
         /// Combines the AntiPrompts list and AntiPrompt csv
         /// </summary>
         /// <param name="sessionConfig">The session configuration.</param>
         /// <returns>Combined AntiPrompts with duplicates removed</returns>
-        public static List<string> GetAntiPrompts(this Common.SessionOptions sessionConfig)
+        public static List<string> GetAntiPrompts(this ISessionConfig sessionConfig)
         {
             return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt);
         }
@@ -19,7 +19,7 @@ namespace LLama.Web
         /// </summary>
         /// <param name="sessionConfig">The session configuration.</param>
         /// <returns>Combined OutputFilters with duplicates removed</returns>
-        public static List<string> GetOutputFilters(this Common.SessionOptions sessionConfig)
+        public static List<string> GetOutputFilters(this ISessionConfig sessionConfig)
         {
             return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter);
         }

LLama.Web/Hubs/SessionConnectionHub.cs (+1, -1)

@@ -37,7 +37,7 @@ namespace LLama.Web.Hubs


         [HubMethodName("LoadModel")]
-        public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig)
+        public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig)
         {
             _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
             await _modelSessionService.CloseAsync(Context.ConnectionId);


LLama.Web/Models/ModelSession.cs (+15, -15)

@@ -9,21 +9,21 @@ namespace LLama.Web.Models
         private readonly LLamaModel _model;
         private readonly LLamaContext _context;
         private readonly ILLamaExecutor _executor;
-        private readonly Common.SessionOptions _sessionParams;
+        private readonly ISessionConfig _sessionConfig;
         private readonly ITextStreamTransform _outputTransform;
         private readonly InferenceOptions _defaultInferenceConfig;

         private CancellationTokenSource _cancellationTokenSource;

-        public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null)
+        public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null)
         {
             _model = model;
             _context = context;
             _sessionId = sessionId;
-            _sessionParams = sessionOptions;
+            _sessionConfig = sessionConfig;
             _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
-            _outputTransform = CreateOutputFilter(_sessionParams);
-            _executor = CreateExecutor(_model, _context, _sessionParams);
+            _outputTransform = CreateOutputFilter();
+            _executor = CreateExecutor();
         }

         /// <summary>
@@ -34,7 +34,7 @@ namespace LLama.Web.Models
         /// <summary>
         /// Gets the name of the model.
         /// </summary>
-        public string ModelName => _sessionParams.Model;
+        public string ModelName => _sessionConfig.Model;

         /// <summary>
         /// Gets the context.
@@ -44,7 +44,7 @@ namespace LLama.Web.Models
         /// <summary>
         /// Gets the session configuration.
         /// </summary>
-        public Common.SessionOptions SessionConfig => _sessionParams;
+        public ISessionConfig SessionConfig => _sessionConfig;

         /// <summary>
         /// Gets the inference parameters.
@@ -60,16 +60,16 @@ namespace LLama.Web.Models
         /// <param name="cancellationToken">The cancellation token.</param>
         internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
         {
-            if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
+            if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless)
                 return;

-            if (string.IsNullOrEmpty(_sessionParams.Prompt))
+            if (string.IsNullOrEmpty(_sessionConfig.Prompt))
                 return;

             // Run Initial prompt
             var inferenceParams = ConfigureInferenceParams(inferenceConfig);
             _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
-            await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
+            await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token))
             {
                 // We dont really need the response of the initial prompt, so exit on first token
                 break;
@@ -114,13 +114,13 @@ namespace LLama.Web.Models
         private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
         {
             var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
-            inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
+            inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts();
             return inferenceParams;
         }

-        private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
+        private ITextStreamTransform CreateOutputFilter()
         {
-            var outputFilters = sessionConfig.GetOutputFilters();
+            var outputFilters = _sessionConfig.GetOutputFilters();
             if (outputFilters.Count > 0)
                 return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);

@@ -128,9 +128,9 @@ namespace LLama.Web.Models
         }


-        private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
+        private ILLamaExecutor CreateExecutor()
         {
-            return sessionConfig.ExecutorType switch
+            return _sessionConfig.ExecutorType switch
             {
                 LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
                 LLamaExecutorType.Instruct => new InstructExecutor(_context),
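For context, a hypothetical call site after this refactor: any ISessionConfig implementation can be passed to the constructor, and the executor and output filter are now built internally from it. The model, context, and sessionId variables and the example values below are illustrative, not from the commit:

// Illustrative only - construct a session from an ISessionConfig and run the initial prompt.
var config = new SessionConfig
{
    Model = "example-model",                      // assumed value
    ExecutorType = LLamaExecutorType.Interactive,
    Prompt = "You are a helpful assistant.",      // assumed value
    AntiPrompt = "User:"
};
var session = new ModelSession(model, context, sessionId, config);
await session.InitializePrompt(cancellationToken: cancellationToken);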


LLama.Web/Pages/Index.cshtml (+5, -5)

@@ -24,11 +24,11 @@
     <div class="d-flex flex-column m-1">
         <div class="d-flex flex-column mb-2">
             <small>Model</small>
-            @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
+            @Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
         </div>
         <div class="d-flex flex-column mb-2">
             <small>Inference Type</small>
-            @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
+            @Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
         </div>
         <nav>
             <div class="nav nav-tabs" id="nav-tab" role="tablist">
@@ -40,17 +40,17 @@
         <div class="tab-pane fade show active" id="nav-prompt" role="tabpanel" aria-labelledby="nav-prompt-tab">
             <div class="d-flex flex-column mb-2">
                 <small>Prompt</small>
-                @Html.TextAreaFor(m => Model.SessionOptions.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
+                @Html.TextAreaFor(m => Model.SessionConfig.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
             </div>

             <div class="d-flex flex-column mb-2">
                 <small>AntiPrompts</small>
-                @Html.TextBoxFor(m => Model.SessionOptions.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
+                @Html.TextBoxFor(m => Model.SessionConfig.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
             </div>

             <div class="d-flex flex-column mb-2">
                 <small>OutputFilter</small>
-                @Html.TextBoxFor(m => Model.SessionOptions.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
+                @Html.TextBoxFor(m => Model.SessionConfig.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
             </div>
         </div>
         <div class="tab-pane fade" id="nav-params" role="tabpanel" aria-labelledby="nav-params-tab">


LLama.Web/Pages/Index.cshtml.cs (+2, -2)

@@ -18,14 +18,14 @@ namespace LLama.Web.Pages
         public LLamaOptions Options { get; set; }

         [BindProperty]
-        public Common.SessionOptions SessionOptions { get; set; }
+        public ISessionConfig SessionConfig { get; set; }

         [BindProperty]
         public InferenceOptions InferenceOptions { get; set; }

         public void OnGet()
         {
-            SessionOptions = new Common.SessionOptions
+            SessionConfig = new SessionConfig
             {
                 Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
                 AntiPrompt = "User:",


LLama.Web/Services/IModelSessionService.cs (+2, -2)

@@ -24,7 +24,7 @@ namespace LLama.Web.Services
         /// Creates a new ModelSession
         /// </summary>
         /// <param name="sessionId">The session identifier.</param>
-        /// <param name="sessionOptions">The session configuration.</param>
+        /// <param name="sessionConfig">The session configuration.</param>
         /// <param name="inferenceOptions">The default inference configuration, will be used for all inference where no infer configuration is supplied.</param>
         /// <param name="cancellationToken">The cancellation token.</param>
         /// <returns></returns>
@@ -33,7 +33,7 @@ namespace LLama.Web.Services
         /// or
         /// Failed to create model session
         /// </exception>
-        Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);
+        Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);


         /// <summary>


LLama.Web/Services/ModelSessionService.cs (+1, -1)

@@ -65,7 +65,7 @@ namespace LLama.Web.Services
         /// or
         /// Failed to create model session
         /// </exception>
-        public async Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
+        public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
         {
             if (_modelSessions.TryGetValue(sessionId, out _))
                 throw new Exception($"Session with id {sessionId} already exists");
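With both the hub method and the service now typed against ISessionConfig, the call chain is unchanged: the Razor-bound SessionConfig (or any other ISessionConfig implementation) flows straight through. A hedged sketch of how the OnLoadModel hub method shown earlier might forward the bound config (its full body is outside this diff):

// Sketch only - assumed continuation of OnLoadModel after CloseAsync; not part of this commit.
var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig);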

