You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

SessionConnectionHub.cs 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. using LLama.Web.Common;
  2. using LLama.Web.Models;
  3. using LLama.Web.Services;
  4. using Microsoft.AspNetCore.SignalR;
  5. namespace LLama.Web.Hubs
  6. {
  /// <summary>
  /// SignalR hub that ties a client connection to a model session: it reports
  /// connection status to the caller, creates/closes the session keyed by the
  /// connection id, and streams inference output for submitted prompts.
  /// </summary>
  public class SessionConnectionHub : Hub<ISessionClient>
  {
      private readonly ILogger<SessionConnectionHub> _logger;
      private readonly IModelSessionService _modelSessionService;

      /// <summary>
      /// Initializes a new instance of the <see cref="SessionConnectionHub"/> class.
      /// </summary>
      /// <param name="logger">Logger used for connection and request tracing.</param>
      /// <param name="modelSessionService">Service that creates, closes and runs inference on model sessions.</param>
      public SessionConnectionHub(ILogger<SessionConnectionHub> logger, IModelSessionService modelSessionService)
      {
          _logger = logger;
          _modelSessionService = modelSessionService;
      }

      /// <summary>
      /// Logs the new connection and tells the caller it is connected.
      /// </summary>
      public override async Task OnConnectedAsync()
      {
          _logger.LogInformation("[OnConnectedAsync], Id: {0}", Context.ConnectionId);

          // Notify client of successful connection
          await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Connected);
          await base.OnConnectedAsync();
      }

      /// <summary>
      /// Logs the disconnect and closes the session registered under this connection id.
      /// </summary>
      /// <param name="exception">The exception forwarded by SignalR to the base implementation.</param>
      public override async Task OnDisconnectedAsync(Exception exception)
      {
          _logger.LogInformation("[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);

          // Remove the connection's session on disconnect
          await _modelSessionService.CloseAsync(Context.ConnectionId);
          await base.OnDisconnectedAsync(exception);
      }

      /// <summary>
      /// Closes any session held under this connection id and creates a new one
      /// from the supplied configuration, then reports the outcome to the caller.
      /// </summary>
      /// <param name="sessionConfig">Session configuration passed to the session service.</param>
      /// <param name="inferenceConfig">Inference options passed to the session service.</param>
      [HubMethodName("LoadModel")]
      public async Task OnLoadModel(SessionConfig sessionConfig, InferenceOptions inferenceConfig)
      {
          _logger.LogInformation("[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);

          // Close first so the connection id can be reused for the new session.
          await _modelSessionService.CloseAsync(Context.ConnectionId);

          // Create model session
          var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig);
          if (modelSession is null)
          {
              await Clients.Caller.OnError("Failed to create model session");
              return;
          }

          // Notify client
          await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Loaded);
      }

      /// <summary>
      /// Streams the tokens produced by the session service for the given prompt.
      /// </summary>
      /// <param name="prompt">Prompt text forwarded to the session service.</param>
      /// <param name="inferConfig">Inference options forwarded to the session service.</param>
      /// <param name="cancellationToken">Token that cancels the stream.</param>
      /// <returns>An asynchronous stream of the tokens yielded by <c>InferAsync</c>.</returns>
      [HubMethodName("SendPrompt")]
      public async IAsyncEnumerable<TokenModel> OnSendPrompt(
          string prompt,
          InferenceOptions inferConfig,
          [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
      {
          _logger.LogInformation("[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId);

          // Stop inference when either the caller cancels the stream or the connection drops.
          // The linked source is owned by this iterator so it is disposed (and its registrations
          // on both parent tokens released) once the stream completes, faults or is cancelled.
          using var linkedCancellation = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken);

          await foreach (var token in _modelSessionService.InferAsync(Context.ConnectionId, prompt, inferConfig, linkedCancellation.Token))
          {
              yield return token;
          }
      }
  }
  53. }