You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

InteractiveHub.cs 4.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. using LLama.Web.Models;
  2. using LLama.Web.Services;
  3. using Microsoft.AspNetCore.SignalR;
  4. using Microsoft.Extensions.Options;
  5. using System.Diagnostics;
  6. namespace LLama.Web.Hubs
  7. {
  8. public class InteractiveHub : Hub<ISessionClient>
  9. {
  10. private readonly LLamaOptions _options;
  11. private readonly ILogger<InteractiveHub> _logger;
  12. private readonly IModelSessionService _modelSessionService;
  13. public InteractiveHub(ILogger<InteractiveHub> logger, IOptions<LLamaOptions> options, IModelSessionService modelSessionService)
  14. {
  15. _logger = logger;
  16. _options = options.Value;
  17. _modelSessionService = modelSessionService;
  18. }
  19. public override async Task OnConnectedAsync()
  20. {
  21. _logger.Log(LogLevel.Information, "OnConnectedAsync, Id: {0}", Context.ConnectionId);
  22. await base.OnConnectedAsync();
  23. await Clients.Caller.OnStatus("Connected", Context.ConnectionId);
  24. }
  25. public override async Task OnDisconnectedAsync(Exception? exception)
  26. {
  27. _logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);
  28. await _modelSessionService.RemoveAsync(Context.ConnectionId);
  29. await base.OnDisconnectedAsync(exception);
  30. }
  31. [HubMethodName("LoadModel")]
  32. public async Task OnLoadModel(string modelName, string promptName, string parameterName)
  33. {
  34. _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}, Model: {1}, Prompt: {2}, Parameter: {3}", Context.ConnectionId, modelName, promptName, parameterName);
  35. await _modelSessionService.RemoveAsync(Context.ConnectionId);
  36. var modelOption = _options.Models.First(x => x.Name == modelName);
  37. var promptOption = _options.Prompts.First(x => x.Name == promptName);
  38. var parameterOption = _options.Parameters.First(x => x.Name == parameterName);
  39. var interactiveExecutor = new InteractiveExecutor(new LLamaModel(modelOption));
  40. var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, interactiveExecutor, modelOption, promptOption, parameterOption);
  41. if (modelSession is null)
  42. {
  43. _logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId);
  44. await Clients.Caller.OnError("No model has been loaded");
  45. return;
  46. }
  47. _logger.Log(LogLevel.Information, "[OnLoadModel] - New model session added, Connection: {0}", Context.ConnectionId);
  48. await Clients.Caller.OnStatus("Loaded", Context.ConnectionId);
  49. }
  50. [HubMethodName("SendPrompt")]
  51. public async Task OnSendPrompt(string prompt)
  52. {
  53. var stopwatch = Stopwatch.GetTimestamp();
  54. _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId);
  55. var modelSession = await _modelSessionService.GetAsync(Context.ConnectionId);
  56. if (modelSession is null)
  57. {
  58. _logger.Log(LogLevel.Warning, "[OnSendPrompt] - No model has been loaded for this connection, Connection: {0}", Context.ConnectionId);
  59. await Clients.Caller.OnError("No model has been loaded");
  60. return;
  61. }
  62. // Create unique response id
  63. var responseId = Guid.NewGuid().ToString();
  64. // Send begin of response
  65. await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true));
  66. // Send content of response
  67. await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted)))
  68. {
  69. await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment));
  70. }
  71. // Send end of response
  72. var elapsedTime = Stopwatch.GetElapsedTime(stopwatch);
  73. var signature = modelSession.IsInferCanceled()
  74. ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds"
  75. : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds";
  76. await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true));
  77. _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled());
  78. }
  79. }
  80. }