| @@ -0,0 +1,13 @@ | |||
| namespace LLama.Web.Common | |||
| { | |||
| public interface ISessionConfig | |||
| { | |||
| string AntiPrompt { get; set; } | |||
| List<string> AntiPrompts { get; set; } | |||
| LLamaExecutorType ExecutorType { get; set; } | |||
| string Model { get; set; } | |||
| string OutputFilter { get; set; } | |||
| List<string> OutputFilters { get; set; } | |||
| string Prompt { get; set; } | |||
| } | |||
| } | |||
| @@ -1,6 +1,6 @@ | |||
| namespace LLama.Web.Common | |||
| { | |||
| public class SessionOptions | |||
| public class SessionConfig : ISessionConfig | |||
| { | |||
| public string Model { get; set; } | |||
| public string Prompt { get; set; } | |||
| @@ -2,14 +2,14 @@ | |||
| namespace LLama.Web | |||
| { | |||
| public static class Extensioms | |||
| public static class Extensions | |||
| { | |||
| /// <summary> | |||
| /// Combines the AntiPrompts list and AntiPrompt csv | |||
| /// </summary> | |||
| /// <param name="sessionConfig">The session configuration.</param> | |||
| /// <returns>Combined AntiPrompts with duplicates removed</returns> | |||
| public static List<string> GetAntiPrompts(this Common.SessionOptions sessionConfig) | |||
| public static List<string> GetAntiPrompts(this ISessionConfig sessionConfig) | |||
| { | |||
| return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt); | |||
| } | |||
| @@ -19,7 +19,7 @@ namespace LLama.Web | |||
| /// </summary> | |||
| /// <param name="sessionConfig">The session configuration.</param> | |||
| /// <returns>Combined OutputFilters with duplicates removed</returns> | |||
| public static List<string> GetOutputFilters(this Common.SessionOptions sessionConfig) | |||
| public static List<string> GetOutputFilters(this ISessionConfig sessionConfig) | |||
| { | |||
| return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter); | |||
| } | |||
| @@ -37,7 +37,7 @@ namespace LLama.Web.Hubs | |||
| [HubMethodName("LoadModel")] | |||
| public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig) | |||
| public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig) | |||
| { | |||
| _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); | |||
| await _modelSessionService.CloseAsync(Context.ConnectionId); | |||
| @@ -9,21 +9,21 @@ namespace LLama.Web.Models | |||
| private readonly LLamaModel _model; | |||
| private readonly LLamaContext _context; | |||
| private readonly ILLamaExecutor _executor; | |||
| private readonly Common.SessionOptions _sessionParams; | |||
| private readonly ISessionConfig _sessionConfig; | |||
| private readonly ITextStreamTransform _outputTransform; | |||
| private readonly InferenceOptions _defaultInferenceConfig; | |||
| private CancellationTokenSource _cancellationTokenSource; | |||
| public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null) | |||
| public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null) | |||
| { | |||
| _model = model; | |||
| _context = context; | |||
| _sessionId = sessionId; | |||
| _sessionParams = sessionOptions; | |||
| _sessionConfig = sessionConfig; | |||
| _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions(); | |||
| _outputTransform = CreateOutputFilter(_sessionParams); | |||
| _executor = CreateExecutor(_model, _context, _sessionParams); | |||
| _outputTransform = CreateOutputFilter(); | |||
| _executor = CreateExecutor(); | |||
| } | |||
| /// <summary> | |||
| @@ -34,7 +34,7 @@ namespace LLama.Web.Models | |||
| /// <summary> | |||
| /// Gets the name of the model. | |||
| /// </summary> | |||
| public string ModelName => _sessionParams.Model; | |||
| public string ModelName => _sessionConfig.Model; | |||
| /// <summary> | |||
| /// Gets the context. | |||
| @@ -44,7 +44,7 @@ namespace LLama.Web.Models | |||
| /// <summary> | |||
| /// Gets the session configuration. | |||
| /// </summary> | |||
| public Common.SessionOptions SessionConfig => _sessionParams; | |||
| public ISessionConfig SessionConfig => _sessionConfig; | |||
| /// <summary> | |||
| /// Gets the inference parameters. | |||
| @@ -60,16 +60,16 @@ namespace LLama.Web.Models | |||
| /// <param name="cancellationToken">The cancellation token.</param> | |||
| internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) | |||
| { | |||
| if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless) | |||
| if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless) | |||
| return; | |||
| if (string.IsNullOrEmpty(_sessionParams.Prompt)) | |||
| if (string.IsNullOrEmpty(_sessionConfig.Prompt)) | |||
| return; | |||
| // Run Initial prompt | |||
| var inferenceParams = ConfigureInferenceParams(inferenceConfig); | |||
| _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); | |||
| await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token)) | |||
| await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token)) | |||
| { | |||
| // We don't really need the response of the initial prompt, so exit on first token | |||
| break; | |||
| @@ -114,13 +114,13 @@ namespace LLama.Web.Models | |||
| private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) | |||
| { | |||
| var inferenceParams = inferenceConfig ?? _defaultInferenceConfig; | |||
| inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts(); | |||
| inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts(); | |||
| return inferenceParams; | |||
| } | |||
| private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig) | |||
| private ITextStreamTransform CreateOutputFilter() | |||
| { | |||
| var outputFilters = sessionConfig.GetOutputFilters(); | |||
| var outputFilters = _sessionConfig.GetOutputFilters(); | |||
| if (outputFilters.Count > 0) | |||
| return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters); | |||
| @@ -128,9 +128,9 @@ namespace LLama.Web.Models | |||
| } | |||
| private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig) | |||
| private ILLamaExecutor CreateExecutor() | |||
| { | |||
| return sessionConfig.ExecutorType switch | |||
| return _sessionConfig.ExecutorType switch | |||
| { | |||
| LLamaExecutorType.Interactive => new InteractiveExecutor(_context), | |||
| LLamaExecutorType.Instruct => new InstructExecutor(_context), | |||
| @@ -24,11 +24,11 @@ | |||
| <div class="d-flex flex-column m-1"> | |||
| <div class="d-flex flex-column mb-2"> | |||
| <small>Model</small> | |||
| @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||
| @Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||
| </div> | |||
| <div class="d-flex flex-column mb-2"> | |||
| <small>Inference Type</small> | |||
| @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||
| @Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) | |||
| </div> | |||
| <nav> | |||
| <div class="nav nav-tabs" id="nav-tab" role="tablist"> | |||
| @@ -40,17 +40,17 @@ | |||
| <div class="tab-pane fade show active" id="nav-prompt" role="tabpanel" aria-labelledby="nav-prompt-tab"> | |||
| <div class="d-flex flex-column mb-2"> | |||
| <small>Prompt</small> | |||
| @Html.TextAreaFor(m => Model.SessionOptions.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8}) | |||
| @Html.TextAreaFor(m => Model.SessionConfig.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8}) | |||
| </div> | |||
| <div class="d-flex flex-column mb-2"> | |||
| <small>AntiPrompts</small> | |||
| @Html.TextBoxFor(m => Model.SessionOptions.AntiPrompt, new { @type="text", @class = "form-control prompt-control"}) | |||
| @Html.TextBoxFor(m => Model.SessionConfig.AntiPrompt, new { @type="text", @class = "form-control prompt-control"}) | |||
| </div> | |||
| <div class="d-flex flex-column mb-2"> | |||
| <small>OutputFilter</small> | |||
| @Html.TextBoxFor(m => Model.SessionOptions.OutputFilter, new { @type="text", @class = "form-control prompt-control"}) | |||
| @Html.TextBoxFor(m => Model.SessionConfig.OutputFilter, new { @type="text", @class = "form-control prompt-control"}) | |||
| </div> | |||
| </div> | |||
| <div class="tab-pane fade" id="nav-params" role="tabpanel" aria-labelledby="nav-params-tab"> | |||
| @@ -18,14 +18,14 @@ namespace LLama.Web.Pages | |||
| public LLamaOptions Options { get; set; } | |||
| [BindProperty] | |||
| public Common.SessionOptions SessionOptions { get; set; } | |||
| public ISessionConfig SessionConfig { get; set; } | |||
| [BindProperty] | |||
| public InferenceOptions InferenceOptions { get; set; } | |||
| public void OnGet() | |||
| { | |||
| SessionOptions = new Common.SessionOptions | |||
| SessionConfig = new SessionConfig | |||
| { | |||
| Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", | |||
| AntiPrompt = "User:", | |||
| @@ -24,7 +24,7 @@ namespace LLama.Web.Services | |||
| /// Creates a new ModelSession | |||
| /// </summary> | |||
| /// <param name="sessionId">The session identifier.</param> | |||
| /// <param name="sessionOptions">The session configuration.</param> | |||
| /// <param name="sessionConfig">The session configuration.</param> | |||
| /// <param name="inferenceOptions">The default inference configuration, will be used for all inference where no infer configuration is supplied.</param> | |||
| /// <param name="cancellationToken">The cancellation token.</param> | |||
| /// <returns></returns> | |||
| @@ -33,7 +33,7 @@ namespace LLama.Web.Services | |||
| /// or | |||
| /// Failed to create model session | |||
| /// </exception> | |||
| Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); | |||
| Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); | |||
| /// <summary> | |||
| @@ -65,7 +65,7 @@ namespace LLama.Web.Services | |||
| /// or | |||
| /// Failed to create model session | |||
| /// </exception> | |||
| public async Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) | |||
| public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) | |||
| { | |||
| if (_modelSessions.TryGetValue(sessionId, out _)) | |||
| throw new Exception($"Session with id {sessionId} already exists"); | |||