You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SessionConnectionHub.cs 4.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. using LLama.Web.Common;
  2. using LLama.Web.Models;
  3. using LLama.Web.Services;
  4. using Microsoft.AspNetCore.SignalR;
  5. using System.Diagnostics;
  namespace LLama.Web.Hubs
  {
      /// <summary>
      /// SignalR hub that keeps one model session per client connection (via
      /// <see cref="ConnectionSessionService"/>, keyed by <c>Context.ConnectionId</c>) and streams
      /// inference output back to the calling client as <see cref="ResponseFragment"/>s.
      /// </summary>
      public class SessionConnectionHub : Hub<ISessionClient>
      {
          private readonly ILogger<SessionConnectionHub> _logger;
          private readonly ConnectionSessionService _modelSessionService;

          /// <summary>
          /// Creates the hub.
          /// </summary>
          /// <param name="logger">Logger for connection and inference events.</param>
          /// <param name="modelSessionService">Service used to create, look up and remove sessions by connection id.</param>
          public SessionConnectionHub(ILogger<SessionConnectionHub> logger, ConnectionSessionService modelSessionService)
          {
              _logger = logger;
              _modelSessionService = modelSessionService;
          }

          /// <summary>
          /// Logs the new connection and sends the caller a <see cref="SessionConnectionStatus.Connected"/> status.
          /// </summary>
          public override async Task OnConnectedAsync()
          {
              _logger.Log(LogLevel.Information, "[OnConnectedAsync], Id: {0}", Context.ConnectionId);

              // Notify client of successful connection
              await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Connected);
              await base.OnConnectedAsync();
          }

          /// <summary>
          /// Removes the session belonging to the disconnecting connection.
          /// </summary>
          /// <param name="exception">The exception that caused the disconnect, if any.</param>
          public override async Task OnDisconnectedAsync(Exception? exception)
          {
              _logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);
              try
              {
                  // Remove the connection's session on disconnect
                  await _modelSessionService.RemoveAsync(Context.ConnectionId);
              }
              finally
              {
                  // Let the base hub complete its disconnect handling even if session cleanup throws.
                  await base.OnDisconnectedAsync(exception);
              }
          }

          /// <summary>
          /// Replaces the caller's session with a newly created one.
          /// </summary>
          /// <param name="executorType">Executor type, passed to <c>ConnectionSessionService.CreateAsync</c>.</param>
          /// <param name="modelName">Model name, passed to <c>ConnectionSessionService.CreateAsync</c>.</param>
          /// <param name="promptName">Prompt name, passed to <c>ConnectionSessionService.CreateAsync</c>.</param>
          /// <param name="parameterName">Parameter-set name, passed to <c>ConnectionSessionService.CreateAsync</c>.</param>
          /// <remarks>
          /// Any existing session for the connection is removed first, so a failed load leaves the
          /// connection with no session. On failure the caller receives <c>OnError</c> with the
          /// service's error; on success it receives <see cref="SessionConnectionStatus.Loaded"/>.
          /// </remarks>
          [HubMethodName("LoadModel")]
          public async Task OnLoadModel(LLamaExecutorType executorType, string modelName, string promptName, string parameterName)
          {
              _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}, Model: {1}, Prompt: {2}, Parameter: {3}", Context.ConnectionId, modelName, promptName, parameterName);

              // Remove the connection's existing session
              await _modelSessionService.RemoveAsync(Context.ConnectionId);

              // Create model session
              var modelSessionResult = await _modelSessionService.CreateAsync(executorType, Context.ConnectionId, modelName, promptName, parameterName);
              if (modelSessionResult.HasError)
              {
                  await Clients.Caller.OnError(modelSessionResult.Error);
                  return;
              }

              // Notify client
              await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Loaded);
          }

          /// <summary>
          /// Runs inference for <paramref name="prompt"/> on the caller's session and streams the result.
          /// </summary>
          /// <param name="prompt">Prompt text sent by the client.</param>
          /// <remarks>
          /// The caller receives a first fragment (<c>isFirst: true</c>), one fragment per inferred
          /// chunk, and a last fragment (<c>isLast: true</c>) whose text reports whether inference
          /// completed or was cancelled and how long it took. All fragments share one GUID response id.
          /// If no session exists, the caller receives <c>OnError</c> instead.
          /// </remarks>
          [HubMethodName("SendPrompt")]
          public async Task OnSendPrompt(string prompt)
          {
              _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId);

              // Get the connection's session
              var modelSession = await _modelSessionService.GetAsync(Context.ConnectionId);
              if (modelSession is null)
              {
                  await Clients.Caller.OnError("No model has been loaded");
                  return;
              }

              // Unique id shared by every fragment of this response
              var responseId = Guid.NewGuid().ToString();

              // Send begin of response
              await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true));

              // Stream content of response. The linked source ties inference to the connection's abort token.
              // NOTE(review): the source is handed to the session and not disposed here; the session may keep
              // it for later cancellation — confirm ownership before wrapping it in a using.
              var startTimestamp = Stopwatch.GetTimestamp();
              await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted)))
              {
                  await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment));
              }

              // Sample the cancellation state exactly once so the client's signature and the log always agree.
              var elapsedTime = Stopwatch.GetElapsedTime(startTimestamp);
              var isCanceled = modelSession.IsInferCanceled();
              var signature = isCanceled
                  ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds"
                  : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds";

              // Send end of response
              await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true));
              _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, isCanceled);
          }
      }
  }