You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelService.cs 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  using System.Collections.Concurrent;
  using LLama.Web.Async;
  using LLama.Web.Common;
  using LLama.Web.Models;
  using Microsoft.Extensions.Options;
  namespace LLama.Web.Services
  {
      /// <summary>
      /// Service for handling models, weights and contexts.
      /// </summary>
      /// <remarks>
      /// Loaded models are cached by <see cref="ModelOptions.Name"/>. Model loading is serialized
      /// by <c>_modelLock</c>; context creation, removal and get-or-create are serialized by
      /// <c>_contextLock</c>.
      /// </remarks>
      public class ModelService : IModelService
      {
          private readonly AsyncLock _modelLock;
          private readonly AsyncLock _contextLock;
          private readonly LLamaOptions _configuration;
          private readonly ILogger<ModelService> _llamaLogger;
          private readonly ConcurrentDictionary<string, LLamaModel> _modelInstances;

          /// <summary>
          /// Initializes a new instance of the <see cref="ModelService"/> class.
          /// </summary>
          /// <param name="configuration">The LLama options (configured models and the model load type).</param>
          /// <param name="llamaLogger">The logger; it is also passed to every <see cref="LLamaModel"/> this service creates.</param>
          public ModelService(IOptions<LLamaOptions> configuration, ILogger<ModelService> llamaLogger)
          {
              _llamaLogger = llamaLogger;
              _modelLock = new AsyncLock();
              _contextLock = new AsyncLock();
              _configuration = configuration.Value;
              _modelInstances = new ConcurrentDictionary<string, LLamaModel>();
          }

          /// <summary>
          /// Loads a model with the provided configuration, or returns the instance already
          /// loaded under the same name.
          /// </summary>
          /// <param name="modelOptions">The model configuration.</param>
          /// <returns>The loaded (or previously cached) <see cref="LLamaModel"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="modelOptions"/> is null.</exception>
          public async Task<LLamaModel> LoadModel(ModelOptions modelOptions)
          {
              ArgumentNullException.ThrowIfNull(modelOptions);

              // Fast path: no lock needed when the model is already cached.
              if (_modelInstances.TryGetValue(modelOptions.Name, out var existingModel))
              {
                  return existingModel;
              }

              using (await _modelLock.LockAsync())
              {
                  // Re-check under the lock: another caller may have loaded it while we waited.
                  if (_modelInstances.TryGetValue(modelOptions.Name, out var model))
                  {
                      return model;
                  }

                  // In single-model modes, unload (and dispose) every other model first.
                  if (_configuration.ModelLoadType == ModelLoadType.Single
                      || _configuration.ModelLoadType == ModelLoadType.PreloadSingle)
                  {
                      await UnloadModels();
                  }

                  model = new LLamaModel(modelOptions, _llamaLogger);
                  _modelInstances.TryAdd(modelOptions.Name, model);
                  return model;
              }
          }

          /// <summary>
          /// Preloads configured models according to the model load type: nothing is loaded for
          /// <c>Single</c> or <c>Multiple</c>; only the first configured model is loaded for
          /// <c>PreloadSingle</c>; otherwise every configured model is loaded.
          /// </summary>
          public async Task LoadModels()
          {
              if (_configuration.ModelLoadType == ModelLoadType.Single
                  || _configuration.ModelLoadType == ModelLoadType.Multiple)
              {
                  return;
              }

              foreach (var modelConfig in _configuration.Models)
              {
                  await LoadModel(modelConfig);

                  // Only preload the first model in PreloadSingle mode.
                  if (_configuration.ModelLoadType == ModelLoadType.PreloadSingle)
                  {
                      break;
                  }
              }
          }

          /// <summary>
          /// Removes the named model from the cache and disposes it.
          /// </summary>
          /// <param name="modelName">Name of the model.</param>
          /// <returns>
          /// An already-completed task. Its underlying <c>Task&lt;bool&gt;</c> result is
          /// <c>true</c> if a model was removed and <c>false</c> if none was loaded under that name.
          /// </returns>
          public Task UnloadModel(string modelName)
          {
              if (_modelInstances.TryRemove(modelName, out var model))
              {
                  model?.Dispose();
                  return Task.FromResult(true);
              }
              return Task.FromResult(false);
          }

          /// <summary>
          /// Unloads and disposes all loaded models.
          /// </summary>
          public async Task UnloadModels()
          {
              // ConcurrentDictionary.Keys returns a snapshot, so removing entries while iterating is safe.
              foreach (var modelName in _modelInstances.Keys)
              {
                  await UnloadModel(modelName);
              }
          }

          /// <summary>
          /// Gets a loaded model by name.
          /// </summary>
          /// <param name="modelName">Name of the model.</param>
          /// <returns>The loaded model, or <c>null</c> if no model with that name is loaded.</returns>
          public Task<LLamaModel> GetModel(string modelName)
          {
              _modelInstances.TryGetValue(modelName, out var model);
              return Task.FromResult(model);
          }

          /// <summary>
          /// Gets a context from the specified model.
          /// </summary>
          /// <param name="modelName">Name of the model.</param>
          /// <param name="contextName">The context name.</param>
          /// <returns>Whatever the model's <c>GetContext</c> returns for <paramref name="contextName"/>.</returns>
          /// <exception cref="KeyNotFoundException">Model not found.</exception>
          public async Task<LLamaContext> GetContext(string modelName, string contextName)
          {
              var model = GetLoadedModel(modelName);
              return await model.GetContext(contextName);
          }

          /// <summary>
          /// Creates a context on the specified model.
          /// </summary>
          /// <param name="modelName">Name of the model.</param>
          /// <param name="contextName">The context name.</param>
          /// <returns>The context created by the model.</returns>
          /// <exception cref="KeyNotFoundException">Model not found.</exception>
          public async Task<LLamaContext> CreateContext(string modelName, string contextName)
          {
              var model = GetLoadedModel(modelName);
              using (await _contextLock.LockAsync())
              {
                  return await model.CreateContext(contextName);
              }
          }

          /// <summary>
          /// Removes a context from the specified model.
          /// </summary>
          /// <param name="modelName">Name of the model.</param>
          /// <param name="contextName">The context name.</param>
          /// <returns>The result of the model's <c>RemoveContext</c> call.</returns>
          /// <exception cref="KeyNotFoundException">Model not found.</exception>
          public async Task<bool> RemoveContext(string modelName, string contextName)
          {
              var model = GetLoadedModel(modelName);
              using (await _contextLock.LockAsync())
              {
                  return await model.RemoveContext(contextName);
              }
          }

          /// <summary>
          /// Gets (loading it from configuration if needed) the named model, then gets or creates
          /// the named context on it.
          /// </summary>
          /// <param name="modelName">Name of the model.</param>
          /// <param name="contextName">The context name.</param>
          /// <returns>The model and the existing or newly created context.</returns>
          /// <exception cref="KeyNotFoundException">
          /// The model is not loaded and no configured model option has that name.
          /// </exception>
          public async Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName)
          {
              if (!_modelInstances.TryGetValue(modelName, out var model))
              {
                  // Not loaded yet: look up its configuration and load it.
                  var modelConfig = _configuration.Models.FirstOrDefault(x => x.Name == modelName)
                      ?? throw new KeyNotFoundException($"Model option '{modelName}' not found");
                  model = await LoadModel(modelConfig);
              }

              return (model, await GetOrCreateContext(model, contextName));
          }

          /// <summary>
          /// Returns the loaded model registered under <paramref name="modelName"/>.
          /// </summary>
          /// <exception cref="KeyNotFoundException">No model with that name is loaded.</exception>
          private LLamaModel GetLoadedModel(string modelName)
          {
              if (!_modelInstances.TryGetValue(modelName, out var model))
              {
                  throw new KeyNotFoundException("Model not found");
              }
              return model;
          }

          /// <summary>
          /// Returns the existing context, or creates it, as a single step under <c>_contextLock</c>.
          /// </summary>
          /// <remarks>
          /// Holding the same lock as <see cref="CreateContext"/> and <see cref="RemoveContext"/>
          /// prevents two concurrent callers from both seeing "no context" and creating it twice.
          /// </remarks>
          private async Task<LLamaContext> GetOrCreateContext(LLamaModel model, string contextName)
          {
              using (await _contextLock.LockAsync())
              {
                  return await model.GetContext(contextName) ?? await model.CreateContext(contextName);
              }
          }
      }
  }