| @@ -45,7 +45,8 @@ namespace LLama.Web.Hubs | |||||
| var modelOption = _options.Models.First(x => x.Name == modelName); | var modelOption = _options.Models.First(x => x.Name == modelName); | ||||
| var promptOption = _options.Prompts.First(x => x.Name == promptName); | var promptOption = _options.Prompts.First(x => x.Name == promptName); | ||||
| var parameterOption = _options.Parameters.First(x => x.Name == parameterName); | var parameterOption = _options.Parameters.First(x => x.Name == parameterName); | ||||
| var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, modelOption, promptOption, parameterOption); | |||||
| var interactiveExecutor = new InteractiveExecutor(new LLamaModel(modelOption)); | |||||
| var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, interactiveExecutor, modelOption, promptOption, parameterOption); | |||||
| if (modelSession is null) | if (modelSession is null) | ||||
| { | { | ||||
| _logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId); | _logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId); | ||||
| @@ -72,15 +73,15 @@ namespace LLama.Web.Hubs | |||||
| } | } | ||||
| // Create unique response id | // Create unique response id | ||||
| modelSession.ResponseId = Guid.NewGuid().ToString(); | |||||
| var responseId = Guid.NewGuid().ToString(); | |||||
| // Send begin of response | // Send begin of response | ||||
| await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, isFirst: true)); | |||||
| await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true)); | |||||
| // Send content of response | // Send content of response | ||||
| await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) | await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) | ||||
| { | { | ||||
| await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, fragment)); | |||||
| await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment)); | |||||
| } | } | ||||
| // Send end of response | // Send end of response | ||||
| @@ -88,7 +89,7 @@ namespace LLama.Web.Hubs | |||||
| var signature = modelSession.IsInferCanceled() | var signature = modelSession.IsInferCanceled() | ||||
| ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" | ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" | ||||
| : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; | : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; | ||||
| await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, signature, isLast: true)); | |||||
| await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); | |||||
| _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); | _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); | ||||
| } | } | ||||
| @@ -9,24 +9,23 @@ namespace LLama.Web.Models | |||||
| private PromptOptions _promptOptions; | private PromptOptions _promptOptions; | ||||
| private ParameterOptions _inferenceOptions; | private ParameterOptions _inferenceOptions; | ||||
| private ITextStreamTransform _outputTransform; | private ITextStreamTransform _outputTransform; | ||||
| private InteractiveExecutor _interactiveExecutor; | |||||
| private ILLamaExecutor _executor; | |||||
| private CancellationTokenSource _cancellationTokenSource; | private CancellationTokenSource _cancellationTokenSource; | ||||
| public ModelSession(string connectionId, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) | |||||
| public ModelSession(string connectionId, ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) | |||||
| { | { | ||||
| ConnectionId = connectionId; | ConnectionId = connectionId; | ||||
| _executor = executor; | |||||
| _modelOptions = modelOptions; | _modelOptions = modelOptions; | ||||
| _promptOptions = promptOptions; | _promptOptions = promptOptions; | ||||
| _inferenceOptions = parameterOptions; | _inferenceOptions = parameterOptions; | ||||
| _interactiveExecutor = new InteractiveExecutor(new LLamaModel(_modelOptions)); | |||||
| _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty<string>()).Distinct() ?? _inferenceOptions.AntiPrompts; | _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty<string>()).Distinct() ?? _inferenceOptions.AntiPrompts; | ||||
| if (_promptOptions.OutputFilter?.Count > 0) | if (_promptOptions.OutputFilter?.Count > 0) | ||||
| _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); | _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); | ||||
| } | } | ||||
| public string ConnectionId { get; } | public string ConnectionId { get; } | ||||
| public string ResponseId { get; set; } | |||||
| public IAsyncEnumerable<string> InferAsync(string message, CancellationTokenSource cancellationTokenSource) | public IAsyncEnumerable<string> InferAsync(string message, CancellationTokenSource cancellationTokenSource) | ||||
| { | { | ||||
| @@ -38,9 +37,9 @@ namespace LLama.Web.Models | |||||
| } | } | ||||
| if (_outputTransform is not null) | if (_outputTransform is not null) | ||||
| return _outputTransform.TransformAsync(_interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); | |||||
| return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); | |||||
| return _interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); | |||||
| return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); | |||||
| } | } | ||||
| @@ -58,8 +57,8 @@ namespace LLama.Web.Models | |||||
| { | { | ||||
| _inferenceOptions = null; | _inferenceOptions = null; | ||||
| _outputTransform = null; | _outputTransform = null; | ||||
| _interactiveExecutor.Model?.Dispose(); | |||||
| _interactiveExecutor = null; | |||||
| _executor.Model?.Dispose(); | |||||
| _executor = null; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -205,7 +205,7 @@ | |||||
| responseContainer = $(`#${response.id}`); | responseContainer = $(`#${response.id}`); | ||||
| responseContent = responseContainer.find(".content"); | responseContent = responseContainer.find(".content"); | ||||
| responseFirstFragment = true; | responseFirstFragment = true; | ||||
| scrollToBottom(); | |||||
| scrollToBottom(true); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -233,6 +233,7 @@ | |||||
| outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); | outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); | ||||
| connection.invoke('SendPrompt', text); | connection.invoke('SendPrompt', text); | ||||
| $("#input").val(null); | $("#input").val(null); | ||||
| scrollToBottom(true); | |||||
| } | } | ||||
| } | } | ||||
| @@ -309,9 +310,13 @@ | |||||
| } | } | ||||
| const scrollToBottom = () => { | |||||
| const scrollToBottom = (force) => { | |||||
| const scrollTop = scrollContainer.scrollTop(); | const scrollTop = scrollContainer.scrollTop(); | ||||
| const scrollHeight = scrollContainer[0].scrollHeight; | const scrollHeight = scrollContainer[0].scrollHeight; | ||||
| if(force){ | |||||
| scrollContainer.scrollTop(scrollContainer[0].scrollHeight); | |||||
| return; | |||||
| } | |||||
| if (scrollTop + 70 >= scrollHeight - scrollContainer.innerHeight()) { | if (scrollTop + 70 >= scrollHeight - scrollContainer.innerHeight()) { | ||||
| scrollContainer.scrollTop(scrollContainer[0].scrollHeight) | scrollContainer.scrollTop(scrollContainer[0].scrollHeight) | ||||
| } | } | ||||
| @@ -1,11 +1,12 @@ | |||||
| using LLama.Web.Models; | |||||
| using LLama.Abstractions; | |||||
| using LLama.Web.Models; | |||||
| namespace LLama.Web.Services | namespace LLama.Web.Services | ||||
| { | { | ||||
| public interface IModelSessionService | public interface IModelSessionService | ||||
| { | { | ||||
| Task<ModelSession> GetAsync(string connectionId); | Task<ModelSession> GetAsync(string connectionId); | ||||
| Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption); | |||||
| Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption); | |||||
| Task RemoveAsync(string connectionId); | Task RemoveAsync(string connectionId); | ||||
| Task CancelAsync(string connectionId); | Task CancelAsync(string connectionId); | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using LLama.Web.Models; | |||||
| using LLama.Abstractions; | |||||
| using LLama.Web.Models; | |||||
| using System.Collections.Concurrent; | using System.Collections.Concurrent; | ||||
| namespace LLama.Web.Services | namespace LLama.Web.Services | ||||
| @@ -20,10 +21,10 @@ namespace LLama.Web.Services | |||||
| return Task.FromResult(modelSession); | return Task.FromResult(modelSession); | ||||
| } | } | ||||
| public Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption) | |||||
| public Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption) | |||||
| { | { | ||||
| //TODO: Max instance etc | //TODO: Max instance etc | ||||
| var modelSession = new ModelSession(connectionId, modelOption, promptOption, parameterOption); | |||||
| var modelSession = new ModelSession( connectionId, executor, modelOption, promptOption, parameterOption); | |||||
| if (!_modelSessions.TryAdd(connectionId, modelSession)) | if (!_modelSessions.TryAdd(connectionId, modelSession)) | ||||
| { | { | ||||
| _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId); | _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId); | ||||