You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelSessionService.cs 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. using LLama.Abstractions;
  2. using LLama.Web.Models;
  3. using System.Collections.Concurrent;
  namespace LLama.Web.Services
  {
      /// <summary>
      /// Keeps at most one <see cref="ModelSession"/> per connection id and exposes
      /// lookup, creation, removal and cancellation of those sessions.
      /// </summary>
      /// <remarks>
      /// A session successfully registered by <see cref="CreateAsync"/> is owned by this service:
      /// it is disposed when removed through <see cref="RemoveAsync"/>. A session that could not be
      /// registered is disposed immediately, since no caller can reach it afterwards.
      /// </remarks>
      public class ModelSessionService : IModelSessionService
      {
          private readonly ILogger<ModelSessionService> _logger;

          // Sessions keyed by connection id. ConcurrentDictionary keeps the individual
          // get/add/remove operations safe if these methods are invoked concurrently.
          private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;

          /// <summary>
          /// Creates the service with an empty session table.
          /// </summary>
          /// <param name="logger">Logger used for session lifecycle messages.</param>
          public ModelSessionService(ILogger<ModelSessionService> logger)
          {
              _logger = logger;
              _modelSessions = new ConcurrentDictionary<string, ModelSession>();
          }

          /// <summary>
          /// Looks up the session registered for a connection.
          /// </summary>
          /// <param name="connectionId">Connection identifier used as the session key.</param>
          /// <returns>
          /// A completed task whose result is the registered session, or <c>null</c> when no
          /// session exists for <paramref name="connectionId"/>.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="connectionId"/> is null.</exception>
          public Task<ModelSession> GetAsync(string connectionId)
          {
              _modelSessions.TryGetValue(connectionId, out var modelSession);
              return Task.FromResult(modelSession);
          }

          /// <summary>
          /// Creates a new <see cref="ModelSession"/> for a connection and registers it.
          /// </summary>
          /// <param name="connectionId">Connection identifier used as the session key.</param>
          /// <param name="executor">Executor handed to the new session.</param>
          /// <param name="modelOption">Model options handed to the new session.</param>
          /// <param name="promptOption">Prompt options handed to the new session.</param>
          /// <param name="parameterOption">Inference parameter options handed to the new session.</param>
          /// <returns>
          /// A completed task whose result is the new session, or <c>null</c> when a session is
          /// already registered for <paramref name="connectionId"/>; the existing session is left as is.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="connectionId"/> is null.</exception>
          public Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption)
          {
              // Validate before constructing the session; a null key would otherwise make TryAdd
              // throw after the session exists, leaking it.
              ArgumentNullException.ThrowIfNull(connectionId);

              //TODO: Max instance etc
              var modelSession = new ModelSession(connectionId, executor, modelOption, promptOption, parameterOption);
              if (!_modelSessions.TryAdd(connectionId, modelSession))
              {
                  _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId);

                  // The new session was never published, so nobody else can dispose it; release it
                  // here rather than leaking it. The already-registered session is not touched.
                  modelSession.Dispose();
                  return Task.FromResult<ModelSession>(default);
              }

              return Task.FromResult(modelSession);
          }

          /// <summary>
          /// Unregisters and disposes the session for a connection, if one exists.
          /// </summary>
          /// <param name="connectionId">Connection identifier used as the session key.</param>
          /// <returns>An already-completed task.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="connectionId"/> is null.</exception>
          public Task RemoveAsync(string connectionId)
          {
              if (_modelSessions.TryRemove(connectionId, out var modelSession))
              {
                  _logger.Log(LogLevel.Information, "[RemoveAsync] - Removed model session, Connection: {0}", connectionId);
                  modelSession.Dispose();
              }

              return Task.CompletedTask;
          }

          /// <summary>
          /// Calls <c>CancelInfer</c> on the session for a connection, if one exists.
          /// The session stays registered.
          /// </summary>
          /// <param name="connectionId">Connection identifier used as the session key.</param>
          /// <returns>An already-completed task.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="connectionId"/> is null.</exception>
          public Task CancelAsync(string connectionId)
          {
              if (_modelSessions.TryGetValue(connectionId, out var modelSession))
              {
                  _logger.Log(LogLevel.Information, "[CancelAsync] - Canceled model session, Connection: {0}", connectionId);
                  modelSession.CancelInfer();
              }

              return Task.CompletedTask;
          }
      }
  }