You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConnectionSessionService.cs 4.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. using LLama.Abstractions;
  2. using LLama.Web.Common;
  3. using LLama.Web.Models;
  4. using Microsoft.Extensions.Options;
  5. using System.Collections.Concurrent;
  6. using System.Drawing;
  namespace LLama.Web.Services
  {
      /// <summary>
      /// Example service for handling a model session for the lifetime of a websocket connection.
      /// Each websocket connection creates its own unique session and context, allowing you to use
      /// multiple tabs to compare prompts etc.
      /// </summary>
      public class ConnectionSessionService : IModelSessionService
      {
          private readonly LLamaOptions _options;
          private readonly ILogger<ConnectionSessionService> _logger;

          // Sessions keyed by websocket connection id; one session per connection.
          private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;

          /// <summary>
          /// Creates the service from the configured <see cref="LLamaOptions"/>.
          /// </summary>
          /// <param name="logger">Logger for this service.</param>
          /// <param name="options">Options providing the available models, prompts and parameter sets.</param>
          public ConnectionSessionService(ILogger<ConnectionSessionService> logger, IOptions<LLamaOptions> options)
          {
              _logger = logger;
              _options = options.Value;
              _modelSessions = new ConcurrentDictionary<string, ModelSession>();
          }

          /// <summary>
          /// Gets the session bound to a connection.
          /// </summary>
          /// <param name="connectionId">The websocket connection id.</param>
          /// <returns>The session, or <c>null</c> if no session exists for <paramref name="connectionId"/>.</returns>
          public Task<ModelSession> GetAsync(string connectionId)
          {
              _modelSessions.TryGetValue(connectionId, out var modelSession);
              return Task.FromResult(modelSession);
          }

          /// <summary>
          /// Creates a model session for a connection using the named model, prompt and parameter options.
          /// </summary>
          /// <param name="executorType">Which executor to run the model with.</param>
          /// <param name="connectionId">The websocket connection id the session is bound to.</param>
          /// <param name="modelName">Name of an entry in <c>LLamaOptions.Models</c>.</param>
          /// <param name="promptName">Name of an entry in <c>LLamaOptions.Prompts</c>.</param>
          /// <param name="parameterName">Name of an entry in <c>LLamaOptions.Parameters</c>.</param>
          /// <returns>
          /// A result holding the new session, or an error when an option is not found, the model's
          /// instance limit is reached, the executor type is unsupported, or the connection already has a session.
          /// </returns>
          public Task<IServiceResult<ModelSession>> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName)
          {
              var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName);
              if (modelOption is null)
                  return Task.FromResult(ServiceResult.FromError<ModelSession>($"Model option '{modelName}' not found"));

              var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName);
              if (promptOption is null)
                  return Task.FromResult(ServiceResult.FromError<ModelSession>($"Prompt option '{promptName}' not found"));

              var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName);
              if (parameterOption is null)
                  return Task.FromResult(ServiceResult.FromError<ModelSession>($"Parameter option '{parameterName}' not found"));

              // Max instances: a MaxInstances of -1 (or lower) disables the limit.
              // NOTE(review): this check-then-add is not atomic, so concurrent creates for the same
              // model may exceed MaxInstances. Confirm whether that matters for this example.
              var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name);
              if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances)
                  return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached"));

              // Resolve the executor before loading the model. An unsupported executor type must not
              // produce a session with a null executor, or pay for a model load it cannot use.
              Func<LLamaModel, ILLamaExecutor> createExecutor = executorType switch
              {
                  LLamaExecutorType.Interactive => model => new InteractiveExecutor(model),
                  LLamaExecutorType.Instruct => model => new InstructExecutor(model),
                  LLamaExecutorType.Stateless => model => new StatelessExecutor(model),
                  _ => null
              };
              if (createExecutor is null)
                  return Task.FromResult(ServiceResult.FromError<ModelSession>($"Executor type '{executorType}' not supported"));

              // A connection owns at most one session. Fail fast before loading a model we could not store.
              if (_modelSessions.ContainsKey(connectionId))
                  return Task.FromResult(ServiceResult.FromError<ModelSession>("Failed to create model session"));

              // Create model, executor and session
              var llamaModel = new LLamaModel(modelOption);
              ILLamaExecutor executor = createExecutor(llamaModel);
              var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption);

              if (!_modelSessions.TryAdd(connectionId, modelSession))
              {
                  // Another create for this connection won the race. Release the session we just built
                  // instead of dropping it undisposed.
                  modelSession.Dispose();
                  return Task.FromResult(ServiceResult.FromError<ModelSession>("Failed to create model session"));
              }

              return Task.FromResult(ServiceResult.FromValue(modelSession));
          }

          /// <summary>
          /// Removes a connection's session, cancelling any running inference and disposing the session.
          /// </summary>
          /// <param name="connectionId">The websocket connection id.</param>
          /// <returns><c>true</c> if a session was removed; <c>false</c> if none existed.</returns>
          public Task<bool> RemoveAsync(string connectionId)
          {
              if (_modelSessions.TryRemove(connectionId, out var modelSession))
              {
                  modelSession.CancelInfer();
                  modelSession.Dispose();
                  return Task.FromResult(true);
              }
              return Task.FromResult(false);
          }

          /// <summary>
          /// Cancels the running inference of a connection's session without removing the session.
          /// </summary>
          /// <param name="connectionId">The websocket connection id.</param>
          /// <returns><c>true</c> if a session was found and cancelled; <c>false</c> if none existed.</returns>
          public Task<bool> CancelAsync(string connectionId)
          {
              if (_modelSessions.TryGetValue(connectionId, out var modelSession))
              {
                  modelSession.CancelInfer();
                  return Task.FromResult(true);
              }
              return Task.FromResult(false);
          }
      }
  }