Browse Source

Move session management into the service layer; use ILLamaExecutor in ModelSession for versatility; fix auto-scroll bug

tags/v0.4.2-preview
sa_ddam213 2 years ago
parent
commit
a139423581
5 changed files with 27 additions and 20 deletions
  1. +6
    -5
      LLama.Web/Hubs/InteractiveHub.cs
  2. +7
    -8
      LLama.Web/Models/ModelSession.cs
  3. +7
    -2
      LLama.Web/Pages/Interactive.cshtml
  4. +3
    -2
      LLama.Web/Services/IModelSessionService.cs
  5. +4
    -3
      LLama.Web/Services/ModelSessionService.cs

+ 6
- 5
LLama.Web/Hubs/InteractiveHub.cs View File

@@ -45,7 +45,8 @@ namespace LLama.Web.Hubs
    var modelOption = _options.Models.First(x => x.Name == modelName);
    var promptOption = _options.Prompts.First(x => x.Name == promptName);
    var parameterOption = _options.Parameters.First(x => x.Name == parameterName);
-   var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, modelOption, promptOption, parameterOption);
+   var interactiveExecutor = new InteractiveExecutor(new LLamaModel(modelOption));
+   var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, interactiveExecutor, modelOption, promptOption, parameterOption);
    if (modelSession is null)
    {
        _logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId);
@@ -72,15 +73,15 @@ namespace LLama.Web.Hubs
    }

    // Create unique response id
-   modelSession.ResponseId = Guid.NewGuid().ToString();
+   var responseId = Guid.NewGuid().ToString();

    // Send begin of response
-   await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, isFirst: true));
+   await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true));

    // Send content of response
    await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted)))
    {
-       await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, fragment));
+       await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment));
    }

    // Send end of response
@@ -88,7 +89,7 @@ namespace LLama.Web.Hubs
    var signature = modelSession.IsInferCanceled()
        ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds"
        : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds";
-   await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, signature, isLast: true));
+   await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true));
    _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled());
    }


+ 7
- 8
LLama.Web/Models/ModelSession.cs View File

@@ -9,24 +9,23 @@ namespace LLama.Web.Models
    private PromptOptions _promptOptions;
    private ParameterOptions _inferenceOptions;
    private ITextStreamTransform _outputTransform;
-   private InteractiveExecutor _interactiveExecutor;
+   private ILLamaExecutor _executor;
    private CancellationTokenSource _cancellationTokenSource;

-   public ModelSession(string connectionId, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions)
+   public ModelSession(string connectionId, ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions)
    {
        ConnectionId = connectionId;
+       _executor = executor;
        _modelOptions = modelOptions;
        _promptOptions = promptOptions;
        _inferenceOptions = parameterOptions;
-       _interactiveExecutor = new InteractiveExecutor(new LLamaModel(_modelOptions));
        _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty<string>()).Distinct() ?? _inferenceOptions.AntiPrompts;
        if (_promptOptions.OutputFilter?.Count > 0)
            _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5);
    }

    public string ConnectionId { get; }
-   public string ResponseId { get; set; }

    public IAsyncEnumerable<string> InferAsync(string message, CancellationTokenSource cancellationTokenSource)
    {
@@ -38,9 +37,9 @@ namespace LLama.Web.Models
    }

    if (_outputTransform is not null)
-       return _outputTransform.TransformAsync(_interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token));
+       return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token));

-   return _interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token);
+   return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token);
    }




@@ -58,8 +57,8 @@ namespace LLama.Web.Models
    {
        _inferenceOptions = null;
        _outputTransform = null;
-       _interactiveExecutor.Model?.Dispose();
-       _interactiveExecutor = null;
+       _executor.Model?.Dispose();
+       _executor = null;
    }
    }
    }

+ 7
- 2
LLama.Web/Pages/Interactive.cshtml View File

@@ -205,7 +205,7 @@
    responseContainer = $(`#${response.id}`);
    responseContent = responseContainer.find(".content");
    responseFirstFragment = true;
-   scrollToBottom();
+   scrollToBottom(true);
    return;
    }


@@ -233,6 +233,7 @@
    outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() }));
    connection.invoke('SendPrompt', text);
    $("#input").val(null);
+   scrollToBottom(true);
    }
    }


@@ -309,9 +310,13 @@
    }

-   const scrollToBottom = () => {
+   const scrollToBottom = (force) => {
        const scrollTop = scrollContainer.scrollTop();
        const scrollHeight = scrollContainer[0].scrollHeight;
+       if (force) {
+           scrollContainer.scrollTop(scrollContainer[0].scrollHeight);
+           return;
+       }
        if (scrollTop + 70 >= scrollHeight - scrollContainer.innerHeight()) {
            scrollContainer.scrollTop(scrollContainer[0].scrollHeight)
        }


+ 3
- 2
LLama.Web/Services/IModelSessionService.cs View File

@@ -1,11 +1,12 @@
-   using LLama.Web.Models;
+   using LLama.Abstractions;
+   using LLama.Web.Models;

    namespace LLama.Web.Services
    {
        public interface IModelSessionService
        {
            Task<ModelSession> GetAsync(string connectionId);
-           Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption);
+           Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption);
            Task RemoveAsync(string connectionId);
            Task CancelAsync(string connectionId);
        }


+ 4
- 3
LLama.Web/Services/ModelSessionService.cs View File

@@ -1,4 +1,5 @@
-   using LLama.Web.Models;
+   using LLama.Abstractions;
+   using LLama.Web.Models;
    using System.Collections.Concurrent;

    namespace LLama.Web.Services
@@ -20,10 +21,10 @@ namespace LLama.Web.Services
        return Task.FromResult(modelSession);
    }

-   public Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption)
+   public Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption)
    {
        //TODO: Max instance etc
-       var modelSession = new ModelSession(connectionId, modelOption, promptOption, parameterOption);
+       var modelSession = new ModelSession( connectionId, executor, modelOption, promptOption, parameterOption);
        if (!_modelSessions.TryAdd(connectionId, modelSession))
        {
            _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId);


Loading…
Cancel
Save