| @@ -0,0 +1,13 @@ | |||||
| namespace LLama.Web.Common | |||||
| { | |||||
| public interface ISessionConfig | |||||
| { | |||||
| string AntiPrompt { get; set; } | |||||
| List<string> AntiPrompts { get; set; } | |||||
| LLamaExecutorType ExecutorType { get; set; } | |||||
| string Model { get; set; } | |||||
| string OutputFilter { get; set; } | |||||
| List<string> OutputFilters { get; set; } | |||||
| string Prompt { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,6 +1,6 @@ | |||||
| namespace LLama.Web.Common | namespace LLama.Web.Common | ||||
| { | { | ||||
| public class SessionOptions | |||||
| public class SessionConfig : ISessionConfig | |||||
| { | { | ||||
| public string Model { get; set; } | public string Model { get; set; } | ||||
| public string Prompt { get; set; } | public string Prompt { get; set; } | ||||
| @@ -2,14 +2,14 @@ | |||||
| namespace LLama.Web | namespace LLama.Web | ||||
| { | { | ||||
| public static class Extensioms | |||||
| public static class Extensions | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Combines the AntiPrompts list and AntiPrompt csv | /// Combines the AntiPrompts list and AntiPrompt csv | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="sessionConfig">The session configuration.</param> | /// <param name="sessionConfig">The session configuration.</param> | ||||
| /// <returns>Combined AntiPrompts with duplicates removed</returns> | /// <returns>Combined AntiPrompts with duplicates removed</returns> | ||||
| public static List<string> GetAntiPrompts(this Common.SessionOptions sessionConfig) | |||||
| public static List<string> GetAntiPrompts(this ISessionConfig sessionConfig) | |||||
| { | { | ||||
| return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt); | return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt); | ||||
| } | } | ||||
| @@ -19,7 +19,7 @@ namespace LLama.Web | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="sessionConfig">The session configuration.</param> | /// <param name="sessionConfig">The session configuration.</param> | ||||
| /// <returns>Combined OutputFilters with duplicates removed</returns> | /// <returns>Combined OutputFilters with duplicates removed</returns> | ||||
| public static List<string> GetOutputFilters(this Common.SessionOptions sessionConfig) | |||||
| public static List<string> GetOutputFilters(this ISessionConfig sessionConfig) | |||||
| { | { | ||||
| return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter); | return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter); | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ namespace LLama.Web.Hubs | |||||
| [HubMethodName("LoadModel")] | [HubMethodName("LoadModel")] | ||||
| public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig) | |||||
| public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig) | |||||
| { | { | ||||
| _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); | _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); | ||||
| await _modelSessionService.CloseAsync(Context.ConnectionId); | await _modelSessionService.CloseAsync(Context.ConnectionId); | ||||
| @@ -9,21 +9,21 @@ namespace LLama.Web.Models | |||||
| private readonly LLamaModel _model; | private readonly LLamaModel _model; | ||||
| private readonly LLamaContext _context; | private readonly LLamaContext _context; | ||||
| private readonly ILLamaExecutor _executor; | private readonly ILLamaExecutor _executor; | ||||
| private readonly Common.SessionOptions _sessionParams; | |||||
| private readonly ISessionConfig _sessionConfig; | |||||
| private readonly ITextStreamTransform _outputTransform; | private readonly ITextStreamTransform _outputTransform; | ||||
| private readonly InferenceOptions _defaultInferenceConfig; | private readonly InferenceOptions _defaultInferenceConfig; | ||||
| private CancellationTokenSource _cancellationTokenSource; | private CancellationTokenSource _cancellationTokenSource; | ||||
| public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null) | |||||
| public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null) | |||||
| { | { | ||||
| _model = model; | _model = model; | ||||
| _context = context; | _context = context; | ||||
| _sessionId = sessionId; | _sessionId = sessionId; | ||||
| _sessionParams = sessionOptions; | |||||
| _sessionConfig = sessionConfig; | |||||
| _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions(); | _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions(); | ||||
| _outputTransform = CreateOutputFilter(_sessionParams); | |||||
| _executor = CreateExecutor(_model, _context, _sessionParams); | |||||
| _outputTransform = CreateOutputFilter(); | |||||
| _executor = CreateExecutor(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -34,7 +34,7 @@ namespace LLama.Web.Models | |||||
| /// <summary> | /// <summary> | ||||
| /// Gets the name of the model. | /// Gets the name of the model. | ||||
| /// </summary> | /// </summary> | ||||
| public string ModelName => _sessionParams.Model; | |||||
| public string ModelName => _sessionConfig.Model; | |||||
| /// <summary> | /// <summary> | ||||
| /// Gets the context. | /// Gets the context. | ||||
| @@ -44,7 +44,7 @@ namespace LLama.Web.Models | |||||
| /// <summary> | /// <summary> | ||||
| /// Gets the session configuration. | /// Gets the session configuration. | ||||
| /// </summary> | /// </summary> | ||||
| public Common.SessionOptions SessionConfig => _sessionParams; | |||||
| public ISessionConfig SessionConfig => _sessionConfig; | |||||
| /// <summary> | /// <summary> | ||||
| /// Gets the inference parameters. | /// Gets the inference parameters. | ||||
| @@ -60,16 +60,16 @@ namespace LLama.Web.Models | |||||
| /// <param name="cancellationToken">The cancellation token.</param> | /// <param name="cancellationToken">The cancellation token.</param> | ||||
| internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) | internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) | ||||
| { | { | ||||
| if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless) | |||||
| if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless) | |||||
| return; | return; | ||||
| if (string.IsNullOrEmpty(_sessionParams.Prompt)) | |||||
| if (string.IsNullOrEmpty(_sessionConfig.Prompt)) | |||||
| return; | return; | ||||
| // Run Initial prompt | // Run Initial prompt | ||||
| var inferenceParams = ConfigureInferenceParams(inferenceConfig); | var inferenceParams = ConfigureInferenceParams(inferenceConfig); | ||||
| _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); | _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); | ||||
| await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token)) | |||||
| await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token)) | |||||
| { | { | ||||
| // We dont really need the response of the initial prompt, so exit on first token | // We dont really need the response of the initial prompt, so exit on first token | ||||
| break; | break; | ||||
| @@ -114,13 +114,13 @@ namespace LLama.Web.Models | |||||
| private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) | private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) | ||||
| { | { | ||||
| var inferenceParams = inferenceConfig ?? _defaultInferenceConfig; | var inferenceParams = inferenceConfig ?? _defaultInferenceConfig; | ||||
| inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts(); | |||||
| inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts(); | |||||
| return inferenceParams; | return inferenceParams; | ||||
| } | } | ||||
| private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig) | |||||
| private ITextStreamTransform CreateOutputFilter() | |||||
| { | { | ||||
| var outputFilters = sessionConfig.GetOutputFilters(); | |||||
| var outputFilters = _sessionConfig.GetOutputFilters(); | |||||
| if (outputFilters.Count > 0) | if (outputFilters.Count > 0) | ||||
| return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters); | return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters); | ||||
| @@ -128,9 +128,9 @@ namespace LLama.Web.Models | |||||
| } | } | ||||
| private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig) | |||||
| private ILLamaExecutor CreateExecutor() | |||||
| { | { | ||||
| return sessionConfig.ExecutorType switch | |||||
| return _sessionConfig.ExecutorType switch | |||||
| { | { | ||||
| LLamaExecutorType.Interactive => new InteractiveExecutor(_context), | LLamaExecutorType.Interactive => new InteractiveExecutor(_context), | ||||
| LLamaExecutorType.Instruct => new InstructExecutor(_context), | LLamaExecutorType.Instruct => new InstructExecutor(_context), | ||||
| @@ -24,11 +24,11 @@ | |||||
| <div class="d-flex flex-column m-1"> | <div class="d-flex flex-column m-1"> | ||||
| <div class="d-flex flex-column mb-2"> | <div class="d-flex flex-column mb-2"> | ||||
| <small>Model</small> | <small>Model</small> | ||||
| @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||||
| @Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||||
| </div> | </div> | ||||
| <div class="d-flex flex-column mb-2"> | <div class="d-flex flex-column mb-2"> | ||||
| <small>Inference Type</small> | <small>Inference Type</small> | ||||
| @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||||
| @Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||||
| </div> | </div> | ||||
| <nav> | <nav> | ||||
| <div class="nav nav-tabs" id="nav-tab" role="tablist"> | <div class="nav nav-tabs" id="nav-tab" role="tablist"> | ||||
| @@ -40,17 +40,17 @@ | |||||
| <div class="tab-pane fade show active" id="nav-prompt" role="tabpanel" aria-labelledby="nav-prompt-tab"> | <div class="tab-pane fade show active" id="nav-prompt" role="tabpanel" aria-labelledby="nav-prompt-tab"> | ||||
| <div class="d-flex flex-column mb-2"> | <div class="d-flex flex-column mb-2"> | ||||
| <small>Prompt</small> | <small>Prompt</small> | ||||
| @Html.TextAreaFor(m => Model.SessionOptions.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8}) | |||||
| @Html.TextAreaFor(m => Model.SessionConfig.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8}) | |||||
| </div> | </div> | ||||
| <div class="d-flex flex-column mb-2"> | <div class="d-flex flex-column mb-2"> | ||||
| <small>AntiPrompts</small> | <small>AntiPrompts</small> | ||||
| @Html.TextBoxFor(m => Model.SessionOptions.AntiPrompt, new { @type="text", @class = "form-control prompt-control"}) | |||||
| @Html.TextBoxFor(m => Model.SessionConfig.AntiPrompt, new { @type="text", @class = "form-control prompt-control"}) | |||||
| </div> | </div> | ||||
| <div class="d-flex flex-column mb-2"> | <div class="d-flex flex-column mb-2"> | ||||
| <small>OutputFilter</small> | <small>OutputFilter</small> | ||||
| @Html.TextBoxFor(m => Model.SessionOptions.OutputFilter, new { @type="text", @class = "form-control prompt-control"}) | |||||
| @Html.TextBoxFor(m => Model.SessionConfig.OutputFilter, new { @type="text", @class = "form-control prompt-control"}) | |||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| <div class="tab-pane fade" id="nav-params" role="tabpanel" aria-labelledby="nav-params-tab"> | <div class="tab-pane fade" id="nav-params" role="tabpanel" aria-labelledby="nav-params-tab"> | ||||
| @@ -18,14 +18,14 @@ namespace LLama.Web.Pages | |||||
| public LLamaOptions Options { get; set; } | public LLamaOptions Options { get; set; } | ||||
| [BindProperty] | [BindProperty] | ||||
| public Common.SessionOptions SessionOptions { get; set; } | |||||
| public ISessionConfig SessionConfig { get; set; } | |||||
| [BindProperty] | [BindProperty] | ||||
| public InferenceOptions InferenceOptions { get; set; } | public InferenceOptions InferenceOptions { get; set; } | ||||
| public void OnGet() | public void OnGet() | ||||
| { | { | ||||
| SessionOptions = new Common.SessionOptions | |||||
| SessionConfig = new SessionConfig | |||||
| { | { | ||||
| Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", | Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", | ||||
| AntiPrompt = "User:", | AntiPrompt = "User:", | ||||
| @@ -24,7 +24,7 @@ namespace LLama.Web.Services | |||||
| /// Creates a new ModelSession | /// Creates a new ModelSession | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="sessionId">The session identifier.</param> | /// <param name="sessionId">The session identifier.</param> | ||||
| /// <param name="sessionOptions">The session configuration.</param> | |||||
| /// <param name="sessionConfig">The session configuration.</param> | |||||
| /// <param name="inferenceOptions">The default inference configuration, will be used for all inference where no infer configuration is supplied.</param> | /// <param name="inferenceOptions">The default inference configuration, will be used for all inference where no infer configuration is supplied.</param> | ||||
| /// <param name="cancellationToken">The cancellation token.</param> | /// <param name="cancellationToken">The cancellation token.</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| @@ -33,7 +33,7 @@ namespace LLama.Web.Services | |||||
| /// or | /// or | ||||
| /// Failed to create model session | /// Failed to create model session | ||||
| /// </exception> | /// </exception> | ||||
| Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); | |||||
| Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); | |||||
| /// <summary> | /// <summary> | ||||
| @@ -65,7 +65,7 @@ namespace LLama.Web.Services | |||||
| /// or | /// or | ||||
| /// Failed to create model session | /// Failed to create model session | ||||
| /// </exception> | /// </exception> | ||||
| public async Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) | |||||
| public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) | |||||
| { | { | ||||
| if (_modelSessions.TryGetValue(sessionId, out _)) | if (_modelSessions.TryGetValue(sessionId, out _)) | ||||
| throw new Exception($"Session with id {sessionId} already exists"); | throw new Exception($"Session with id {sessionId} already exists"); | ||||